path: root/ethosu/vela/scheduler.py
diff options
Diffstat (limited to 'ethosu/vela/scheduler.py')
1 files changed, 129 insertions, 43 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 24453d8c..5c2ddabb 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -959,52 +959,66 @@ class DynamicProgrammingScheduler:
self.sg.cascaded_passes = cascaded_passes
- if self.options.use_nhcwb16_between_cascaded_passes:
- # Check if NHCWB16 can be used in between cascaded passes
- # (NHCWB16 within cascaded passes has been handled earlier in this function)
- if self.sg.placement == PassPlacement.Npu:
- last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
- for ps in self.sg.cascaded_passes:
- if ps.placement != PassPlacement.Npu:
+ # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
+ # (NHCWB16 within cascaded passes has been handled earlier in this function)
+ if self.sg.placement == PassPlacement.Npu:
+ # Dictionary tensor -> list of ops, containing feature maps that can be attempted
+ # to be moved to fast storage
+ fast_storage_tensor_rewrites = {}
+ last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
+ for ps in self.sg.cascaded_passes:
+ if ps.placement != PassPlacement.Npu:
+ continue
+ for output in ps.outputs:
+ if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
- for output in ps.outputs:
- if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
- continue
- use_NHCWB16 = True
- rewrites = []
- for op in output.consumer_list:
- if op is None or (op.type == "ReduceSum" and output.dtype == DataType.int32):
- use_NHCWB16 = False
- elif op.type == "Reshape":
- # Detect no-op reshapes by comparing their full input and output tensor shapes.
- inshape = full_shape(4, op.inputs[0].shape, 1)
- outshape = full_shape(4, op.outputs[0].shape, 1)
- # Using NHCWB16 format for a no-op reshape is only an option if subsequent
- # consumers do not also need to perform a reshape or if the OFM is going to
- # be processed by CPU operations. No-op reshape consumers with empty lists
- # (those that have no consumers, or null-consumers used as list terminators)
- # must use normal NHWC output.
- incompatible_consumers = [
- (
- not consumer.run_on_npu
- or consumer.type == "Reshape"
- or (consumer is last_op_in_subgraph)
- )
- for consumer in op.outputs[0].consumer_list
- if consumer is not None
- ]
- if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
- rewrites.append(op)
- else:
- use_NHCWB16 = False
+ use_NHCWB16 = True
+ use_fast_storage = True
+ rewrites = []
+ for op in output.consumer_list:
+ if op is None:
+ use_NHCWB16 = False
+ use_fast_storage = False
+ continue
+ if op.type == "ReduceSum" and output.dtype == DataType.int32:
+ use_NHCWB16 = False
+ elif op.type == "Reshape":
+ # Detect no-op reshapes by comparing their full input and output tensor shapes.
+ inshape = full_shape(4, op.inputs[0].shape, 1)
+ outshape = full_shape(4, op.outputs[0].shape, 1)
+ # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+ # consumers do not also need to perform a reshape or if the OFM is going to
+ # be processed by CPU operations. No-op reshape consumers with empty lists
+ # (those that have no consumers, or null-consumers used as list terminators)
+ # must use normal NHWC output.
+ incompatible_consumers = [
+ (
+ not consumer.run_on_npu
+ or consumer.type == "Reshape"
+ or (consumer is last_op_in_subgraph)
+ )
+ for consumer in op.outputs[0].consumer_list
+ if consumer is not None
+ ]
+ if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
+ rewrites.append(op)
- use_NHCWB16 &= op.run_on_npu
- if use_NHCWB16:
- output.set_format(TensorFormat.NHCWB16, arch)
- for rewrite_op in rewrites:
- rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+ use_NHCWB16 = False
+ use_fast_storage = False
+ use_NHCWB16 &= op.run_on_npu
+ use_fast_storage &= op.run_on_npu
+ if use_fast_storage:
+ fast_storage_tensor_rewrites[output] = rewrites
+ if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
+ output.set_format(TensorFormat.NHCWB16, arch)
+ for rewrite_op in rewrites:
+ rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+ if self.feature_maps_not_in_fast_storage:
+ # Remember feature maps that can be moved to fast storage for later use
+ # in use_fast_storage_for_feature_maps
+ self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites
def schedule_passes(nng, arch, options: SchedulerOptions):
@@ -1027,3 +1041,75 @@ def schedule_passes(nng, arch, options: SchedulerOptions):
if options.verbose_schedule:
+def _calc_tens_to_cps(sg, tensor_rewrites):
+ # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
+ # Returns dictionary tensor -> list of cascaded passes
+ # Note: if cascaded passes are A, B, C, D, and a tensor is output
+ # of A and input to D, then it also consumes SRAM in passes B and C.
+ if "tens_to_cps" in sg.scheduling_info:
+ return sg.scheduling_info["tens_to_cps"]
+ # Determine life-time of tensors
+ min_index = {}
+ max_index = {}
+ index = 0
+ cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
+ for cps in cps_list:
+ for tens in cps.inputs + cps.outputs:
+ if tens in tensor_rewrites:
+ min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
+ max_index[tens] = index
+ index += 1
+ # Convert to affected cps-es
+ tens_to_cps = {}
+ for tens in min_index:
+ tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
+ sg.scheduling_info["tens_to_cps"] = tens_to_cps
+ return tens_to_cps
def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
    """Attempt to use as much fast storage as possible for feature maps shared between cascaded passes.

    Candidate tensors are read from ``sg.scheduling_info["feature_map_rewrites"]``
    (tensor -> list of ops whose first output must be moved together with the
    tensor). Candidates are visited greedily: shortest life-time first, then
    biggest storage size first, then by name. A tensor is moved only if adding
    its storage size keeps ``sram_used`` within ``sram_limit`` for every
    cascaded pass in its life-time; when it is moved, ``sram_used`` of each of
    those passes is increased by that size.

    Tensors whose life-time covers at most one cascaded pass are left untouched.

    Args:
        sg: Subgraph providing ``scheduling_info`` and the cascaded passes.
        sram_limit: Upper bound for ``sram_used`` of any single cascaded pass.
        arch: Architecture object providing ``fast_storage_mem_area``.

    Raises:
        AssertionError: If a tensor selected for fast storage has no entry in
            the recorded rewrites. This is checked before anything is modified.
    """
    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)

    def priority(tens):
        # Life-time (smallest first), then size (biggest first), then name.
        # Sorting on an explicit key means tensor objects themselves are never
        # compared when all three fields tie; ties keep dictionary order.
        return (len(tens_to_cps[tens]), -tens.storage_size(), tens.name)

    for tens in sorted(tens_to_cps, key=priority):
        cps_list = tens_to_cps[tens]
        if len(cps_list) <= 1:
            continue
        sz = tens.storage_size()
        fits_in_fast_storage = all(cps.sram_used + sz <= sram_limit for cps in cps_list)
        if not fits_in_fast_storage:
            continue
        # Validate before mutating so a failure cannot leave a half-moved tensor
        assert tens in tensor_rewrites
        # Move the tensor itself, and also the outputs of its recorded reshapes
        for fast_tens in [tens] + [rewrite_op.outputs[0] for rewrite_op in tensor_rewrites[tens]]:
            fast_tens.mem_area = arch.fast_storage_mem_area
            fast_tens.mem_type = MemType.Scratch_fast
            fast_tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
        # Account for the tensor in every cascaded pass of its life-time
        for cps in cps_list:
            cps.sram_used += sz
def undo_use_fast_storage(sg, arch):
    """Undo the effects of a previous call to use_fast_storage_for_feature_maps.

    Every tracked tensor whose ``mem_type`` is ``MemType.Scratch_fast`` is put
    back in the feature map memory area with ``MemType.Scratch``, together with
    the first output of each op recorded for it in
    ``sg.scheduling_info["feature_map_rewrites"]``. The tensor's storage size is
    subtracted from ``sram_used`` of each cascaded pass in its life-time.

    Args:
        sg: Subgraph providing ``scheduling_info`` and the cascaded passes.
        arch: Architecture object providing ``tensor_storage_mem_area``.
    """
    rewrites_by_tensor = sg.scheduling_info.get("feature_map_rewrites", {})
    lifetimes = _calc_tens_to_cps(sg, rewrites_by_tensor)
    default_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]

    for fm, passes in lifetimes.items():
        if fm.mem_type != MemType.Scratch_fast:
            # Not in fast storage, nothing to revert for this tensor
            continue
        released = fm.storage_size()
        fm.mem_area = default_area
        fm.mem_type = MemType.Scratch
        # Also undo reshapes
        for reshape_op in rewrites_by_tensor[fm]:
            reshaped = reshape_op.outputs[0]
            reshaped.mem_area = default_area
            reshaped.mem_type = MemType.Scratch
        # Give the SRAM back to every cascaded pass in the life-time
        for cascaded in passes:
            cascaded.sram_used -= released