From ba69518cef84a495c104e51d100875cdca717a22 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Wed, 26 Aug 2020 17:27:19 +0100 Subject: MLBEDSW-2686: Use NPU tensor format for noop reshapes. - Reshapes that merely add/remove dimensions, rather than re-layout the data need not fall back to NHWC. This commit allows reshapes betweeen NPU operators to use NHCWB16. Signed-off-by: Tim Hall Change-Id: Ieb7745e586bf324e92e741a04b74caf7285f4b8b --- ethosu/vela/scheduler.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 9a8215d5..9b492f01 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -42,7 +42,7 @@ from .tensor import MemType from .tensor import TensorFormat from .tensor import TensorPurpose from .tensor import TensorSubPurpose - +from .numeric_util import full_shape class ParetoMetric(enum.Enum): BwCycMem = 1 @@ -957,21 +957,36 @@ class DynamicProgrammingScheduler: if ps.placement != PassPlacement.Npu: continue for output in ps.outputs: - if output.purpose != TensorPurpose.FeatureMap: + if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16: continue - use_NHCWB16 = not output.avoid_NHCWB16 - - if use_NHCWB16: - # Check consumers, to see if NHCWB16 can be used in the output - for op in output.consumer_list: - if op is None or op.type == "Reshape": - use_NHCWB16 = False + use_NHCWB16 = True + rewrites = [] + for op in output.consumer_list: + if op is None: + use_NHCWB16 = False + elif op.type == "Reshape": + # Detect no-op reshapes by comparing their full input and output tensor shapes. + inshape = full_shape(4, op.inputs[0].shape, 1) + outshape = full_shape(4, op.outputs[0].shape, 1) + # Using NHCWB16 format for a no-op reshape is only an option if subsequent + # consumers do not also need to perform a reshape or if the OFM is going to + # be processed by CPU operations. No-op reshape consumers with empty lists + # (those that have no consumers, or null-consumers used as list terminators) + # must use normal NHWC output. + incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape") for consumer in op.outputs[0].consumer_list + if consumer is not None ] + if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers): + rewrites.append(op) else: - use_NHCWB16 &= op.run_on_npu + use_NHCWB16 = False + else: + use_NHCWB16 &= op.run_on_npu if use_NHCWB16: output.set_format(TensorFormat.NHCWB16, arch) + for rewrite_op in rewrites: + rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch) def schedule_passes(nng, arch, options: SchedulerOptions): -- cgit v1.2.1