diff options
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r-- | ethosu/vela/scheduler.py | 172 |
1 files changed, 129 insertions, 43 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 24453d8c..5c2ddabb 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -959,52 +959,66 @@ class DynamicProgrammingScheduler: self.sg.cascaded_passes = cascaded_passes self.sg.build_cascaded_pass_links() - if self.options.use_nhcwb16_between_cascaded_passes: - # Check if NHCWB16 can be used in between cascaded passes - # (NHCWB16 within cascaded passes has been handled earlier in this function) - if self.sg.placement == PassPlacement.Npu: - last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op - for ps in self.sg.cascaded_passes: - if ps.placement != PassPlacement.Npu: + # Check if NHCWB16 and/or fast storage can be used in between cascaded passes + # (NHCWB16 within cascaded passes has been handled earlier in this function) + if self.sg.placement == PassPlacement.Npu: + # Dictionary tensor -> list of ops, containing feature maps that can be attempted + # to be moved to fast storage + fast_storage_tensor_rewrites = {} + last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op + for ps in self.sg.cascaded_passes: + if ps.placement != PassPlacement.Npu: + continue + for output in ps.outputs: + if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16: continue - for output in ps.outputs: - if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16: - continue - use_NHCWB16 = True - rewrites = [] - for op in output.consumer_list: - if op is None or (op.type == "ReduceSum" and output.dtype == DataType.int32): - use_NHCWB16 = False - elif op.type == "Reshape": - # Detect no-op reshapes by comparing their full input and output tensor shapes. - inshape = full_shape(4, op.inputs[0].shape, 1) - outshape = full_shape(4, op.outputs[0].shape, 1) - # Using NHCWB16 format for a no-op reshape is only an option if subsequent - # consumers do not also need to perform a reshape or if the OFM is going to - # be processed by CPU operations. No-op reshape consumers with empty lists - # (those that have no consumers, or null-consumers used as list terminators) - # must use normal NHWC output. - incompatible_consumers = [ - ( - not consumer.run_on_npu - or consumer.type == "Reshape" - or (consumer is last_op_in_subgraph) - ) - for consumer in op.outputs[0].consumer_list - if consumer is not None - ] - if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers): - rewrites.append(op) - else: - use_NHCWB16 = False + use_NHCWB16 = True + use_fast_storage = True + rewrites = [] + for op in output.consumer_list: + if op is None: + use_NHCWB16 = False + use_fast_storage = False + continue + if op.type == "ReduceSum" and output.dtype == DataType.int32: + use_NHCWB16 = False + elif op.type == "Reshape": + # Detect no-op reshapes by comparing their full input and output tensor shapes. + inshape = full_shape(4, op.inputs[0].shape, 1) + outshape = full_shape(4, op.outputs[0].shape, 1) + # Using NHCWB16 format for a no-op reshape is only an option if subsequent + # consumers do not also need to perform a reshape or if the OFM is going to + # be processed by CPU operations. No-op reshape consumers with empty lists + # (those that have no consumers, or null-consumers used as list terminators) + # must use normal NHWC output. + incompatible_consumers = [ + ( + not consumer.run_on_npu + or consumer.type == "Reshape" + or (consumer is last_op_in_subgraph) + ) + for consumer in op.outputs[0].consumer_list + if consumer is not None + ] + if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers): + rewrites.append(op) else: - use_NHCWB16 &= op.run_on_npu - - if use_NHCWB16: - output.set_format(TensorFormat.NHCWB16, arch) - for rewrite_op in rewrites: - rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch) + use_NHCWB16 = False + use_fast_storage = False + use_NHCWB16 &= op.run_on_npu + use_fast_storage &= op.run_on_npu + + if use_fast_storage: + fast_storage_tensor_rewrites[output] = rewrites + if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes: + output.set_format(TensorFormat.NHCWB16, arch) + for rewrite_op in rewrites: + rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch) + if self.feature_maps_not_in_fast_storage: + # Remember feature maps that can be moved to fast storage for later use + # in use_fast_storage_for_feature_maps + self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites def schedule_passes(nng, arch, options: SchedulerOptions): @@ -1027,3 +1041,75 @@ def schedule_passes(nng, arch, options: SchedulerOptions): if options.verbose_schedule: sg.print_cascaded_passes() + + +def _calc_tens_to_cps(sg, tensor_rewrites): + # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption. + # Returns dictionary tensor -> list of cascaded passes + # Note: if cascaded passes are A, B, C, D, and a tensor is output + # of A and input to D, then it also consumes SRAM in passes B and C. + if "tens_to_cps" in sg.scheduling_info: + return sg.scheduling_info["tens_to_cps"] + # Determine life-time of tensors + min_index = {} + max_index = {} + index = 0 + cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu] + for cps in cps_list: + for tens in cps.inputs + cps.outputs: + if tens in tensor_rewrites: + min_index[tens] = min(index, min_index.get(tens, len(cps_list))) + max_index[tens] = index + index += 1 + # Convert to affected cps-es + tens_to_cps = {} + for tens in min_index: + tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1] + sg.scheduling_info["tens_to_cps"] = tens_to_cps + return tens_to_cps + + +def use_fast_storage_for_feature_maps(sg, sram_limit, arch): + # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes. + tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {}) + tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites) + # Sort tensors first on life-time (smallest first), then on size (biggest first) + tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps]) + for _, _, _, tens in tens_list: + cps_list = tens_to_cps[tens] + if len(cps_list) <= 1: + continue + sz = tens.storage_size() + fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list]) + if fits_in_fast_storage: + tens.mem_area = arch.fast_storage_mem_area + tens.mem_type = MemType.Scratch_fast + tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None) + assert tens in tensor_rewrites + # Also rewrite reshapes + for rewrite_op in tensor_rewrites[tens]: + tens2 = rewrite_op.outputs[0] + tens2.mem_area = arch.fast_storage_mem_area + tens2.mem_type = MemType.Scratch_fast + tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None) + for cps in cps_list: + cps.sram_used += sz + + +def undo_use_fast_storage(sg, arch): + # Undoes the effects of a previous call to use_fast_storage_for_feature_maps + tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {}) + tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites) + mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] + for tens, cps_list in tens_to_cps.items(): + if tens.mem_type == MemType.Scratch_fast: + sz = tens.storage_size() + tens.mem_area = mem_area + tens.mem_type = MemType.Scratch + # Also undo reshapes + for rewrite_op in tensor_rewrites[tens]: + tens2 = rewrite_op.outputs[0] + tens2.mem_area = mem_area + tens2.mem_type = MemType.Scratch + for cps in cps_list: + cps.sram_used -= sz |