author     Tim Hall <tim.hall@arm.com>    2022-05-04 16:20:43 +0100
committer  Tim Hall <tim.hall@arm.com>    2022-05-04 16:26:09 +0100
commit     b5df773e92051004158046b0ed2c7b802198de6e (patch)
tree       7d738a28a63b66a20f379acbdfc6c3e7c4a98a61
parent     95b07c1c0fed6a985607131e59a593786d40b389 (diff)
download   ethos-u-vela-b5df773e92051004158046b0ed2c7b802198de6e.tar.gz
Revert "MLBEDSW-6263: Use separate tensors for double buffering"
This reverts commit cc5f4de1c35ba44fca7ff6295c6ae846f8242344.

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0fa5babfe9ad9ec668720d04fe1c16d9a9092131
-rw-r--r--  ethosu/vela/cascade_builder.py                       12
-rw-r--r--  ethosu/vela/high_level_command_stream_generator.py    9
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py          38
-rw-r--r--  ethosu/vela/live_range.py                            23
-rw-r--r--  ethosu/vela/npu_performance.py                       14
-rw-r--r--  ethosu/vela/scheduler.py                             59
-rw-r--r--  ethosu/vela/weight_compressor.py                     15
7 files changed, 75 insertions, 95 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 0d25ec6..4c3f75b 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -144,8 +144,10 @@ class CascadeBuilder:
# Keep track of which Ops are in the proposed cascade as well as the best cascade so far
ops_in_cascade = [op]
ops_in_best_cascade = [op]
- # Get the size of the weight buffer(s)
- weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)
+ # Get the size of the weight buffer
+ weight_buffer = 0
+ if ref_cost[op].buffered_weight_tensor:
+ weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size()
# The first IFM needs to be stored in full
cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
@@ -188,8 +190,10 @@ class CascadeBuilder:
op_full_ofm = current_op.ofm_size_in_bytes()
_, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)
- # Get the size of the weight buffer(s)
- op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)
+ # Get the size of the weight buffer
+ op_weight_buffer = 0
+ if ref_cost[current_op].buffered_weight_tensor:
+ op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size()
# Calculate the uncascaded memory requirement for current Op
uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)
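For reference, the per-op weight-buffer accounting this file returns to can be sketched in isolation. This is an illustration only; FakeTensor and FakeCost are stand-ins for the Vela tensor and cost objects, not classes from the code base.

# After the revert an op has at most one optional buffered weight tensor,
# so its size is read directly instead of summing over a list of tensors.
class FakeTensor:
    def __init__(self, size_bytes):
        self.size_bytes = size_bytes

    def storage_size(self):
        return self.size_bytes

class FakeCost:
    def __init__(self, buffered_weight_tensor=None):
        self.buffered_weight_tensor = buffered_weight_tensor

def weight_buffer_bytes(cost):
    return cost.buffered_weight_tensor.storage_size() if cost.buffered_weight_tensor else 0

print(weight_buffer_bytes(FakeCost()))                   # 0 - weights not buffered
print(weight_buffer_bytes(FakeCost(FakeTensor(16384))))  # 16384 - size of the single buffer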
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 81c0d5b..136f5a9 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -204,12 +204,9 @@ def generate_high_level_commands_for_sched_op(sched_op, schedule):
if op_info.npu_weights_tensor:
weight_box = Box([0, 0, 0, start_channel], [1, 1, 1, end_channel])
- if op_info.buffered_weight_tensors and is_first_h_stripe:
- idx = depth_idx % len(op_info.buffered_weight_tensors)
- yield from dma_if_necessary(
- sched_op.parent_ps, weight_box, op_info.buffered_weight_tensors[idx]
- )
- weight_tensor = op_info.buffered_weight_tensors[idx]
+ if op_info.buffered_weight_tensor and is_first_h_stripe:
+ yield from dma_if_necessary(sched_op.parent_ps, weight_box, op_info.buffered_weight_tensor)
+ weight_tensor = op_info.buffered_weight_tensor
else:
weight_box = None
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index e6bfc1c..3a78d6f 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -68,6 +68,7 @@ from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
+from .tensor import TensorSubPurpose
from .weight_compressor import NpuWeightTensor
from .weight_compressor import WeightKey
@@ -201,15 +202,9 @@ def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
return mem_limits
-def get_upscale(op: Operation) -> NpuResamplingMode:
- upscale = NpuResamplingMode.NONE
- if op.type == Op.ResizeBilinear:
- # perform nearest neighbor upscale
- upscale = NpuResamplingMode.NEAREST
- elif op.type == Op.Conv2DBackpropInputSwitchedBias:
- # perform insert zero upscale
- upscale = NpuResamplingMode.TRANSPOSE
- return upscale
+def get_double_buffer_offset(arch: ArchitectureFeatures, range_index: int, core: int) -> int:
+ """Returns 0 if the first half of a double buffer should be used, 1 if the second half should be used"""
+ return ((range_index - core) // arch.ncores) % 2
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
@@ -319,13 +314,20 @@ def create_weights(
key = WeightKey(core, weight_box.start_coord[-1])
if key in w_tensor_src.encoded_ranges:
weight_range = w_tensor_src.encoded_ranges[key]
- if weight_tensor == w_tensor_src:
- # Straight from source tensor
- address = weight_tensor.address + weight_range.offset
- else:
- # Weight buffered tensor
+ if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
+ assert weight_tensor != w_tensor_src
+ # Double buffered inside weight_tensor
address = weight_tensor.address + core_offset
+ address += get_double_buffer_offset(arch, weight_range.index, core) * w_tensor_src.max_range_bytes
core_offset += round_up(weight_range.total_bytes, 16)
+ else:
+ if weight_tensor == w_tensor_src:
+ # Straight from source tensor
+ address = weight_tensor.address + weight_range.offset
+ else:
+ # Single buffered inside weight tensor
+ address = weight_tensor.address + core_offset
+ core_offset += round_up(weight_range.total_bytes, 16)
# Location of weights in tensor
addr_range = NpuAddressRange(
@@ -524,7 +526,13 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
if core == 0:
weight_range = cmd.in_tensor.encoded_ranges[key]
src_addr = cmd.in_tensor.address + weight_range.offset
- dest_addr = cmd.out_tensor.address
+
+ if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
+ dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (
+ get_double_buffer_offset(arch, weight_range.index, core)
+ )
+ else:
+ dest_addr = cmd.out_tensor.address
else:
start_coord = cmd.box.start_coord
src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
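The half-selection logic restored in get_double_buffer_offset() and create_dma_op() above can be exercised standalone as below. The sketch takes the core count as a plain argument instead of an ArchitectureFeatures object, and the buffer size, base address, and round-robin range-to-core assignment are made-up assumptions for illustration.

def get_double_buffer_offset(ncores, range_index, core):
    """Returns 0 if the first half of a double buffer should be used, 1 if the second half should be used"""
    return ((range_index - core) // ncores) % 2

ncores = 2
max_range_bytes = 2048   # hypothetical size of the largest encoded weight range
base_address = 0x1000    # hypothetical address of the double-buffered weight tensor

# Assuming ranges are handed out to cores round-robin, core 0 sees every second range
# and its DMA destination alternates between the two halves of the same tensor.
for range_index in range(0, 8, ncores):
    half = get_double_buffer_offset(ncores, range_index, core=0)
    dest_addr = base_address + half * max_range_bytes
    print(range_index, half, hex(dest_addr))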
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index ccf4929..19d0c11 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -63,7 +63,7 @@ class LiveRange:
def mark_usage(self, op_time, op_length=1):
op_time_start = max(op_time, 0)
op_time_end = op_time + op_length
- if op_time_end < op_time_start:
+ if op_time_end <= op_time_start:
return
self.start_time = min(self.start_time, op_time_start)
@@ -325,20 +325,13 @@ def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set,
rng.mark_usage(time_to_set)
- for idx, weight_tens in enumerate(op_info.buffered_weight_tensors):
- if weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
- rng = lr_graph.get_or_create_range(weight_tens)
- start_time = time_to_set
- length = 1
- if weight_tens.pre_buffer:
- start_time -= 1
- length += 1
- if len(op_info.buffered_weight_tensors) > 1:
- last_idx = len(op_info.ofm_depth_slices) % len(op_info.buffered_weight_tensors)
- # Double buffering: reduce end time of the buffer that is not used last
- if last_idx != idx:
- length -= 1
- rng.mark_usage(start_time, length)
+ weight_tens = op_info.buffered_weight_tensor
+ if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+ rng = lr_graph.get_or_create_range(weight_tens)
+ if weight_tens.pre_buffer:
+ rng.mark_usage(time_to_set - 1, 2)
+ else:
+ rng.mark_usage(time_to_set)
if time_to_set == lr_graph.current_time:
lr_graph.current_time += 2
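The pre-buffer handling that comes back in _extract_live_ranges_from_schedule() above, together with the <= guard in mark_usage(), can be sketched with a minimal stand-in class (not Vela's LiveRange, just the two fields needed for the example):

class LiveRangeSketch:
    def __init__(self):
        self.start_time = 1 << 30
        self.end_time = -1

    def mark_usage(self, op_time, op_length=1):
        op_time_start = max(op_time, 0)
        op_time_end = op_time + op_length
        if op_time_end <= op_time_start:   # zero-length usage is ignored after the revert
            return
        self.start_time = min(self.start_time, op_time_start)
        self.end_time = max(self.end_time, op_time_end)

rng = LiveRangeSketch()
time_to_set = 4
pre_buffer = True
if pre_buffer:
    rng.mark_usage(time_to_set - 1, 2)   # buffer becomes live one timestep early
else:
    rng.mark_usage(time_to_set)
print(rng.start_time, rng.end_time)      # 3 5 when pre-buffered, 4 5 otherwise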
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 0c8a907..81d0be7 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -620,8 +620,8 @@ def estimate_full_op_performance(
prev_cost = schedule.cost_map[prev_op] if prev_op else None
if op.parent_op.bias:
query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
- if cost.buffered_weight_tensors:
- query.const_memory_area = cost.buffered_weight_tensors[0].mem_area
+ if cost.buffered_weight_tensor:
+ query.const_memory_area = cost.buffered_weight_tensor.mem_area
else:
query.const_memory_area = cost.npu_weights_tensor.mem_area
@@ -649,7 +649,7 @@ def estimate_full_op_performance(
# LUT read from SHRAM TODO remove?
scaled_bws[lut_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
- if cost.npu_weights_tensor and cost.buffered_weight_tensors:
+ if cost.npu_weights_tensor and cost.buffered_weight_tensor:
# DMA Weight Transfer
sz = 0
# Get the size of the first DMA
@@ -661,10 +661,10 @@ def estimate_full_op_performance(
total_sz = len(cost.npu_weights_tensor.buffer)
bws[cost.npu_weights_tensor.mem_area][TensorPurpose.Weights][BandwidthDirection.Read] += total_sz
- bws[cost.buffered_weight_tensors[0].mem_area][TensorPurpose.Weights][BandwidthDirection.Write] += total_sz
+ bws[cost.buffered_weight_tensor.mem_area][TensorPurpose.Weights][BandwidthDirection.Write] += total_sz
ws_first_transfer_cycles = measure_mem2mem_cycles(
- arch, cost.npu_weights_tensor.mem_area, cost.buffered_weight_tensors[0].mem_area, sz
+ arch, cost.npu_weights_tensor.mem_area, cost.buffered_weight_tensor.mem_area, sz
)
# Add cycles for Weight + Scale Transfer
@@ -720,7 +720,7 @@ def estimate_full_op_performance(
bw = access.const_read[0] * bandwidth_compression_scale_approx
bws[query.const_memory_area][TensorPurpose.Weights][BandwidthDirection.Read] += bw
- if not cost.buffered_weight_tensors:
+ if not cost.buffered_weight_tensor:
scaled_bws[query.const_memory_area][TensorPurpose.Weights][BandwidthDirection.Read] += bw
if access.const_read[1] > 0:
@@ -728,7 +728,7 @@ def estimate_full_op_performance(
bw = access.const_read[1] * op.parent_op.bias.element_size()
bws[query.const_memory_area][TensorPurpose.FSBias][BandwidthDirection.Read] += bw
- if not cost.buffered_weight_tensors:
+ if not cost.buffered_weight_tensor:
scaled_bws[query.const_memory_area][TensorPurpose.FSBias][BandwidthDirection.Read] += bw
update_summary_cycles(arch, scaled_bws, cycles_a)
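The bandwidth bookkeeping in the hunks above reads the full encoded weight stream from the permanent weight tensor's memory area and writes it into the single buffered weight tensor's area. A toy version of that accounting, with made-up sizes and area names rather than Vela's memory-area and purpose enums:

from collections import defaultdict

bws = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
total_sz = 4096                # hypothetical len(cost.npu_weights_tensor.buffer)
weights_area, buffer_area = "Dram", "Sram"

bws[weights_area]["Weights"]["Read"] += total_sz    # read from the encoded weight tensor
bws[buffer_area]["Weights"]["Write"] += total_sz    # written into the buffered weight tensor
print(dict(bws[weights_area]["Weights"]), dict(bws[buffer_area]["Weights"]))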
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index dde51c0..e73a26d 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -106,7 +106,7 @@ class SchedulerOpInfo:
self.ofm_depth_slices: List[int] = [0, stripe.depth]
self.npu_weights_tensor: Optional[NpuWeightTensor] = None
self.npu_scales_tensor: Optional[NpuWeightTensor] = None
- self.buffered_weight_tensors: List[Tensor] = []
+ self.buffered_weight_tensor: Optional[Tensor] = None
self.cycles: Optional[CycleCost] = None
self.slack_buffering_cycles = 0
self.slack_buffering_memory = 0
@@ -130,8 +130,9 @@ class SchedulerOpInfo:
res += f"\t\tIFM2 Stripe = {self.stripe_input2}\n"
res += f"\t\tOFM Stripe = {self.stripe}\n"
res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n"
- for idx, tens in enumerate(self.buffered_weight_tensors):
- res += f"\t\tWeight buffer{idx + 1} = {tens.storage_size()} bytes\n"
+ res += (
+ f"\t\tWeight buffer = {self.buffered_weight_tensor and self.buffered_weight_tensor.storage_size()} bytes\n"
+ )
res += f"\t\tDepth slices = {self.ofm_depth_slices}\n"
res += f"\t\tAssigned Cascade = {self.cascade}"
return res
@@ -719,7 +720,7 @@ class Scheduler:
# Chosen buffering might not fit at all, iterate until it does
# or until the minimum usable slice size is reached
if (
- encoded_weights.double_buffer_size() <= buffer_limit_bytes
+ encoded_weights.max_range_bytes <= half_buffer_limit
or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth
):
break
@@ -736,40 +737,24 @@ class Scheduler:
cost.slack_buffering_cycles = tail_cycles.op_cycles
# Determine whether the weights need to be double buffered
- weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes())
+ weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes)
# Only buffer weights if there's still space left for the buffer
if weight_buffer_size <= buffer_limit_bytes:
assert weight_buffer_size % 16 == 0
# Determine whether to double buffer or single buffer
- double_buffer_size = encoded_weights.double_buffer_size()
- if (double_buffer_size <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
+ if (weight_buffer_size * 2 <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
+ weight_buffer_size = weight_buffer_size * 2
weight_tensor_purpose = TensorSubPurpose.DoubleBuffer
else:
weight_tensor_purpose = TensorSubPurpose.Standard
- cost.buffered_weight_tensors = [
- self.buffer_tensor(
- encoded_weights,
- weight_tensor_purpose,
- encoded_weights.double_buffer_sizes[0],
- weight_tensor.name + "_buffer",
- )
- ]
- if weight_tensor_purpose == TensorSubPurpose.DoubleBuffer:
- buf2 = self.buffer_tensor(
- encoded_weights,
- weight_tensor_purpose,
- encoded_weights.double_buffer_sizes[1],
- weight_tensor.name + "_buffer2",
- )
- cost.buffered_weight_tensors.append(buf2)
- last_used_buffer_idx = len(cost.ofm_depth_slices) % 2
- weight_buffer_size = encoded_weights.double_buffer_sizes[last_used_buffer_idx]
+ cost.buffered_weight_tensor = self.buffer_tensor(
+ encoded_weights, weight_tensor_purpose, weight_buffer_size, weight_tensor.name
+ )
if ref_cost.cascade == 0:
- # Determine if the lifetime can be extended and pre-buffer the first weight buffer
- # under the previous operation
- cost.buffered_weight_tensors[0].pre_buffer = encoded_weights.double_buffer_sizes[0] < slack_memory
+ # Determine if the lifetime can be extended and pre-buffer weights under the previous operation
+ cost.buffered_weight_tensor.pre_buffer = weight_buffer_size < slack_memory
cost.slack_buffering_memory -= weight_buffer_size
else:
@@ -782,7 +767,7 @@ class Scheduler:
cost.npu_scales_tensor = encoded_scales
def buffer_tensor(self, src_tensor: Tensor, sub_purpose: TensorSubPurpose, buffer_size: int, name: str) -> Tensor:
- buffered_weight_tensor = Tensor([1, 1, 1, buffer_size], DataType.uint8, name)
+ buffered_weight_tensor = Tensor([1, 1, 1, buffer_size], DataType.uint8, name + "_buffer")
buffered_weight_tensor.src_tensor = src_tensor
buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area
buffered_weight_tensor.mem_type = MemType.Scratch_fast
@@ -824,13 +809,11 @@ class Scheduler:
# Create a cost entry with the new stripe
cost = sched_op.create_scheduler_info(self.nng, stripe)
- for buffered_tens in ref_cost[sched_op].buffered_weight_tensors:
+ if ref_cost[sched_op].buffered_weight_tensor:
# If the weights are buffered in the reference schedule they should be in the new proposal
weight_tensor = cost.npu_weights_tensor
- cost.buffered_weight_tensors.append(
- self.buffer_tensor(
- weight_tensor, TensorSubPurpose.Standard, buffered_tens.storage_size(), buffered_tens.name
- )
+ cost.buffered_weight_tensor = self.buffer_tensor(
+ weight_tensor, TensorSubPurpose.Standard, len(weight_tensor.buffer), weight_tensor.name
)
# Estimate performance
@@ -859,7 +842,9 @@ class Scheduler:
peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
else:
# This Op is not part of a cascade - calculate the memory usage
- op_weight_buffer = sum(tens.storage_size() for tens in cost[sched_op].buffered_weight_tensors)
+ op_weight_buffer = 0
+ if cost[sched_op].buffered_weight_tensor:
+ op_weight_buffer = cost[sched_op].buffered_weight_tensor.storage_size()
op_mem_usage = (
sched_op.ifm_size_in_bytes()
@@ -998,8 +983,8 @@ class Scheduler:
sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()
# Ensure that the src_tensor reference is set correctly
- for tens in op_info.buffered_weight_tensors:
- tens.src_tensor = op_info.npu_weights_tensor
+ if op_info.buffered_weight_tensor:
+ op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor
def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
scratched_fms = {}
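The single-versus-double buffering choice the scheduler reverts to can be summarised as a small standalone helper. The sizes below are invented and the function only mirrors the decision visible in the hunks above; it is not the Scheduler method itself.

def choose_weight_buffering(encoded_buffer_len, max_range_bytes, buffer_limit_bytes):
    # The buffer never needs to exceed the full encoded stream or the largest single range
    weight_buffer_size = min(encoded_buffer_len, max_range_bytes)
    if weight_buffer_size > buffer_limit_bytes:
        return None, 0   # no room to buffer the weights at all
    if weight_buffer_size * 2 <= buffer_limit_bytes and weight_buffer_size < encoded_buffer_len:
        # Both halves fit inside one tensor, so double buffer within it
        return "DoubleBuffer", weight_buffer_size * 2
    return "Standard", weight_buffer_size

print(choose_weight_buffering(encoded_buffer_len=8192, max_range_bytes=1024, buffer_limit_bytes=4096))
# ('DoubleBuffer', 2048)
print(choose_weight_buffering(encoded_buffer_len=8192, max_range_bytes=3072, buffer_limit_bytes=4096))
# ('Standard', 3072)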
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 78c4351..86b424a 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -68,19 +68,12 @@ class NpuWeightTensor(Tensor):
def __init__(self, name):
Tensor.__init__(self, None, None, name + "_npu_encoded_weights")
self.buffer = []
- self.double_buffer_sizes = [0, 0] # Required sizes if double buffering is used
+ self.max_range_bytes = 0
self.encoded_ranges = OrderedDict()
self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
self.dtype = DataType.uint8
self.scale_compression_config = None
- def max_range_bytes(self):
- return max(self.double_buffer_sizes)
-
- def double_buffer_size(self):
- """Return total required size for double buffering"""
- return sum(self.double_buffer_sizes)
-
class CompressedWeightCache:
"""Global tensor weight compression cache"""
@@ -364,7 +357,7 @@ def encode_weight_and_scale_tensor(
weights = np.flip(weights, axis=(0, 1))
encoded_stream = bytearray()
- double_buffer_sizes = [0, 0]
+ max_single_buffer_len = 0
is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise
# Bias & scale
@@ -442,11 +435,11 @@ def encode_weight_and_scale_tensor(
npu_tensor.encoded_ranges[key] = weight_range
# Remember maximum encoded length for DoubleBuffering
- double_buffer_sizes[idx % 2] = max(double_buffer_sizes[idx % 2], len(encoded_stream) - buffer_start_offset)
+ max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)
# Attach buffer to tensor
npu_tensor.buffer = encoded_stream
- npu_tensor.double_buffer_sizes = double_buffer_sizes
+ npu_tensor.max_range_bytes = max_single_buffer_len
npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
npu_tensor.format = TensorFormat.WeightsCompressed
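Finally, the max_range_bytes field reinstated in weight_compressor.py records the largest single encoded range, which is used as the stride between the two halves of a double-buffered weight tensor. A toy version of that bookkeeping with invented range lengths:

encoded_range_lengths = [480, 512, 496, 528]   # hypothetical per-range encoded sizes

encoded_stream_len = 0
max_single_buffer_len = 0
for length in encoded_range_lengths:
    buffer_start_offset = encoded_stream_len
    encoded_stream_len += length
    # Remember maximum encoded length for DoubleBuffering
    max_single_buffer_len = max(max_single_buffer_len, encoded_stream_len - buffer_start_offset)

print(max_single_buffer_len)   # 528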