aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/scheduler.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-09-15 14:05:38 +0200
committerLouis Verhaard <louis.verhaard@arm.com>2020-09-25 08:32:55 +0200
commit0b9c9a3873da3d368e184308f4f9a4c202e3fb67 (patch)
treef057bfeacd6c6323cd0e7f33cd7ec33f941b4f8d /ethosu/vela/scheduler.py
parent8854dc9088586e1bb0bf2640b2289903dfa3c822 (diff)
downloadethos-u-vela-0b9c9a3873da3d368e184308f4f9a4c202e3fb67.tar.gz
MLBEDSW-2337: Intermediate feature maps in fast storage
Attempts to use fast storage for feature maps used in between cascaded passes. This is only relevant for system configurations where feature maps are by default not placed in SRAM, but there is SRAM for fast storage. Change-Id: I207b7cf32cfcb5bea3e6b93c2da1161c4af5221d Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r--ethosu/vela/scheduler.py172
1 files changed, 129 insertions, 43 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 24453d8c..5c2ddabb 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -959,52 +959,66 @@ class DynamicProgrammingScheduler:
self.sg.cascaded_passes = cascaded_passes
self.sg.build_cascaded_pass_links()
- if self.options.use_nhcwb16_between_cascaded_passes:
- # Check if NHCWB16 can be used in between cascaded passes
- # (NHCWB16 within cascaded passes has been handled earlier in this function)
- if self.sg.placement == PassPlacement.Npu:
- last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
- for ps in self.sg.cascaded_passes:
- if ps.placement != PassPlacement.Npu:
+ # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
+ # (NHCWB16 within cascaded passes has been handled earlier in this function)
+ if self.sg.placement == PassPlacement.Npu:
+ # Dictionary tensor -> list of ops, containing feature maps that can be attempted
+ # to be moved to fast storage
+ fast_storage_tensor_rewrites = {}
+ last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
+ for ps in self.sg.cascaded_passes:
+ if ps.placement != PassPlacement.Npu:
+ continue
+ for output in ps.outputs:
+ if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
continue
- for output in ps.outputs:
- if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
- continue
- use_NHCWB16 = True
- rewrites = []
- for op in output.consumer_list:
- if op is None or (op.type == "ReduceSum" and output.dtype == DataType.int32):
- use_NHCWB16 = False
- elif op.type == "Reshape":
- # Detect no-op reshapes by comparing their full input and output tensor shapes.
- inshape = full_shape(4, op.inputs[0].shape, 1)
- outshape = full_shape(4, op.outputs[0].shape, 1)
- # Using NHCWB16 format for a no-op reshape is only an option if subsequent
- # consumers do not also need to perform a reshape or if the OFM is going to
- # be processed by CPU operations. No-op reshape consumers with empty lists
- # (those that have no consumers, or null-consumers used as list terminators)
- # must use normal NHWC output.
- incompatible_consumers = [
- (
- not consumer.run_on_npu
- or consumer.type == "Reshape"
- or (consumer is last_op_in_subgraph)
- )
- for consumer in op.outputs[0].consumer_list
- if consumer is not None
- ]
- if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
- rewrites.append(op)
- else:
- use_NHCWB16 = False
+ use_NHCWB16 = True
+ use_fast_storage = True
+ rewrites = []
+ for op in output.consumer_list:
+ if op is None:
+ use_NHCWB16 = False
+ use_fast_storage = False
+ continue
+ if op.type == "ReduceSum" and output.dtype == DataType.int32:
+ use_NHCWB16 = False
+ elif op.type == "Reshape":
+ # Detect no-op reshapes by comparing their full input and output tensor shapes.
+ inshape = full_shape(4, op.inputs[0].shape, 1)
+ outshape = full_shape(4, op.outputs[0].shape, 1)
+ # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+ # consumers do not also need to perform a reshape or if the OFM is going to
+ # be processed by CPU operations. No-op reshape consumers with empty lists
+ # (those that have no consumers, or null-consumers used as list terminators)
+ # must use normal NHWC output.
+ incompatible_consumers = [
+ (
+ not consumer.run_on_npu
+ or consumer.type == "Reshape"
+ or (consumer is last_op_in_subgraph)
+ )
+ for consumer in op.outputs[0].consumer_list
+ if consumer is not None
+ ]
+ if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
+ rewrites.append(op)
else:
- use_NHCWB16 &= op.run_on_npu
-
- if use_NHCWB16:
- output.set_format(TensorFormat.NHCWB16, arch)
- for rewrite_op in rewrites:
- rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+ use_NHCWB16 = False
+ use_fast_storage = False
+ use_NHCWB16 &= op.run_on_npu
+ use_fast_storage &= op.run_on_npu
+
+ if use_fast_storage:
+ fast_storage_tensor_rewrites[output] = rewrites
+ if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
+ output.set_format(TensorFormat.NHCWB16, arch)
+ for rewrite_op in rewrites:
+ rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+ if self.feature_maps_not_in_fast_storage:
+ # Remember feature maps that can be moved to fast storage for later use
+ # in use_fast_storage_for_feature_maps
+ self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites
def schedule_passes(nng, arch, options: SchedulerOptions):
@@ -1027,3 +1041,75 @@ def schedule_passes(nng, arch, options: SchedulerOptions):
if options.verbose_schedule:
sg.print_cascaded_passes()
+
+
+def _calc_tens_to_cps(sg, tensor_rewrites):
+ # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
+ # Returns dictionary tensor -> list of cascaded passes
+ # Note: if cascaded passes are A, B, C, D, and a tensor is output
+ # of A and input to D, then it also consumes SRAM in passes B and C.
+ if "tens_to_cps" in sg.scheduling_info:
+ return sg.scheduling_info["tens_to_cps"]
+ # Determine life-time of tensors
+ min_index = {}
+ max_index = {}
+ index = 0
+ cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
+ for cps in cps_list:
+ for tens in cps.inputs + cps.outputs:
+ if tens in tensor_rewrites:
+ min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
+ max_index[tens] = index
+ index += 1
+ # Convert to affected cps-es
+ tens_to_cps = {}
+ for tens in min_index:
+ tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
+ sg.scheduling_info["tens_to_cps"] = tens_to_cps
+ return tens_to_cps
+
+
+def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
+ # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
+ tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
+ tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
+ # Sort tensors first on life-time (smallest first), then on size (biggest first)
+ tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
+ for _, _, _, tens in tens_list:
+ cps_list = tens_to_cps[tens]
+ if len(cps_list) <= 1:
+ continue
+ sz = tens.storage_size()
+ fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
+ if fits_in_fast_storage:
+ tens.mem_area = arch.fast_storage_mem_area
+ tens.mem_type = MemType.Scratch_fast
+ tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
+ assert tens in tensor_rewrites
+ # Also rewrite reshapes
+ for rewrite_op in tensor_rewrites[tens]:
+ tens2 = rewrite_op.outputs[0]
+ tens2.mem_area = arch.fast_storage_mem_area
+ tens2.mem_type = MemType.Scratch_fast
+ tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
+ for cps in cps_list:
+ cps.sram_used += sz
+
+
+def undo_use_fast_storage(sg, arch):
+ # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
+ tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
+ tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
+ mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
+ for tens, cps_list in tens_to_cps.items():
+ if tens.mem_type == MemType.Scratch_fast:
+ sz = tens.storage_size()
+ tens.mem_area = mem_area
+ tens.mem_type = MemType.Scratch
+ # Also undo reshapes
+ for rewrite_op in tensor_rewrites[tens]:
+ tens2 = rewrite_op.outputs[0]
+ tens2.mem_area = mem_area
+ tens2.mem_type = MemType.Scratch
+ for cps in cps_list:
+ cps.sram_used -= sz