author     Johan Alfvén <johan.alfven@arm.com>  2022-09-28 20:06:25 +0200
committer  Johan Alfvén <johan.alfven@arm.com>  2022-10-28 09:46:11 +0200
commit     48e5159e8b34abe91f331d76e746c25b4017a96e (patch)
tree       a78c984d43574f614c7573e15650ec2b45e91cff
parent     8484d6e529bc7828d3e5034cd9dfcfb1ddb0559a (diff)
MLBEDSW-6975: Updated bypass functionality
- The previous patch, which always replaced ifm with ofm, introduced
  unnecessary avg pool ops in some cases. That patch has been reverted
  and this is a new solution.

- Replace ifm with ofm under the following conditions:
  a) Ops that depend on the original ifm tensor shape not being changed
     by the bypass memory op function.
  b) When the memory op has different IFM and OFM rank.

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I16a023e169ae64c5db46f6f88516a5e1ca7ed7ef
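To make the two conditions above concrete, here is a minimal, self-contained sketch in plain Python. The Tens/Oper classes and need_to_keep_ofm_shape are hypothetical stand-ins for illustration only, not Vela's real Tensor/Operation classes or the bypass_need_to_keep_ofm_shape helper shown in the diff below:

```python
# Hypothetical stand-ins, only to illustrate conditions a) and b) above.
class Tens:
    def __init__(self, shape, consumers=()):
        self.shape = list(shape)
        self.consumer_list = list(consumers)

class Oper:
    def __init__(self, op_type, ifm=None, ofm=None):
        self.type = op_type
        self.ifm = ifm
        self.ofm = ofm

ORIGINAL_IFM_SHAPE_OPS = ("Mean",)  # mirrors original_ifm_shape_ops in the diff

def need_to_keep_ofm_shape(op):
    # a) an OFM consumer depends on the original ifm shape being preserved
    consumer_needs_it = any(
        c is not None and c.type in ORIGINAL_IFM_SHAPE_OPS for c in op.ofm.consumer_list
    )
    # b) the memory op changes the tensor rank
    rank_differs = len(op.ifm.shape) != len(op.ofm.shape)
    return consumer_needs_it or rank_differs

# RESHAPE 1x6x6x10 -> 1x20x3x6 whose output feeds a MEAN
mean = Oper("Mean")
reshape = Oper("Reshape", ifm=Tens([1, 6, 6, 10]), ofm=Tens([1, 20, 3, 6], [mean]))
print(need_to_keep_ofm_shape(reshape))  # True, because the OFM is consumed by a MEAN
```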
-rw-r--r--  ethosu/vela/graph_optimiser_util.py        67
-rw-r--r--  ethosu/vela/tflite_supported_operators.py   9
2 files changed, 66 insertions, 10 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index b33851a8..e2ee06b8 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -39,6 +39,10 @@ memory_only_ops = (
Op.Identity,
)
+# Ops that depend on the original ifm tensor shape not being changed
+# by the bypass memory op function
+original_ifm_shape_ops = (Op.Mean,)
+
def _avoid_nhcwb16_for_concat(tens):
# If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -195,6 +199,14 @@ def set_ifm_ofm_op_shapes(op, arch, nng):
return op
+def bypass_need_to_keep_ofm_shape(op):
+ # Check if ifm must be replaced by ofm (the rank changes or an op that follows must see the original ifm shape)
+ ifm_replaced_by_ofm = any(
+ ofm_cons is not None and ofm_cons.type in original_ifm_shape_ops for ofm_cons in op.ofm.consumer_list
+ ) or len(op.ifm.shape) != len(op.ofm.shape)
+ return ifm_replaced_by_ofm
+
+
def bypass_memory_only_ops(op):
assert op.type in memory_only_ops
ofm = op.ofm
@@ -211,7 +223,7 @@ def bypass_memory_only_ops(op):
# This case should be handled prior to this function
assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
- if ofm_is_sg_ofm or ofm_is_cpu_consumed:
+ if (ifm.shape != ofm.shape) and (ofm_is_sg_ofm or ofm_is_cpu_consumed or bypass_need_to_keep_ofm_shape(op)):
# Bypassed by replacing ifm with ofm
ofm.ops = []
for prev_op in ifm.ops:
@@ -261,6 +273,20 @@ def record_optimised(op, arch):
DebugDatabase.add_optimised(op, op)
+def insert_copy_op_before_op(op):
+ # Create an avg_pool nop op with ifm as input
+ tens = op.ifm
+ copy_tens = tens.clone()
+ copy_op = create_avgpool_nop(f"{tens.name}_avgpool")
+ copy_op.add_input_tensor(tens)
+ copy_op.set_output_tensor(copy_tens)
+ copy_op.set_ifm_ofm_shapes()
+
+ op.set_input_tensor(copy_tens, 0)
+
+ DebugDatabase.add_optimised(op, copy_op)
+
+
def insert_copy_op_after_tens(tens):
tens_cons_list_copy = tens.consumer_list.copy()
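As a rough picture of the rewiring that insert_copy_op_before_op performs, here is a hedged sketch using toy graph objects. Names like Tens, Oper and insert_copy_before are made up for illustration; the real helper uses Vela's create_avgpool_nop, tensor cloning and set_input_tensor:

```python
# Toy graph objects; only the wiring pattern mirrors insert_copy_op_before_op.
class Tens:
    def __init__(self, name):
        self.name = name
        self.consumer_list = []

class Oper:
    def __init__(self, name, op_type):
        self.name, self.type = name, op_type
        self.inputs, self.outputs = [], []

    def add_input(self, tens):
        self.inputs.append(tens)
        tens.consumer_list.append(self)

def insert_copy_before(op, ifm):
    # Clone the ifm, route it through an AvgPool NOP, and make the clone the
    # op's new input so the original ifm keeps its shape for other consumers.
    copy_tens = Tens(ifm.name + "_copy")
    nop = Oper(ifm.name + "_avgpool", "AvgPool")
    nop.add_input(ifm)
    nop.outputs.append(copy_tens)
    ifm.consumer_list.remove(op)   # op no longer reads the original ifm
    op.inputs[0] = copy_tens
    copy_tens.consumer_list.append(op)
    return nop

ifm = Tens("add_ofm")
mean = Oper("mean", "Mean")
mean.add_input(ifm)
insert_copy_before(mean, ifm)
print([c.name for c in ifm.consumer_list])  # ['add_ofm_avgpool']
print(mean.inputs[0].name)                  # 'add_ofm_copy'
```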
@@ -293,6 +319,31 @@ def fix_sg_input_output(op, arch, nng):
# This is also valid when reshape ifm/ofm is produced respectively
# consumed by CPU
+ # Rare case: original_ifm_shape_ops contains ops that depend on the
+ # original ifm tensor shape not being changed by the bypass memory op
+ # function. If the memory only op's ifm is a subgraph ifm, is cpu
+ # produced, or is consumed by more than one op, then an avgpool NOP
+ # needs to be inserted before the original_ifm_shape_ops op. Note that
+ # the NOP is only inserted before original_ifm_shape_ops. The above is
+ # also true when the memory only op changes the rank between IFM and OFM.
+ #
+ # Below is an example showing the case when there is a need for an AVG NOP
+ # when RESHAPE is bypassed by replacing IFM with OFM.
+ #
+ #             Converts to                   And in bypass_memory
+ #                --->                               --->
+ #  -----ADD-----            -----ADD-----            -----ADD-----
+ #  |           |            |           |            |           |
+ # 1x6x6x10  1x6x6x10       1x6x6x10  1x6x6x10       1x6x6x10  1x6x6x10
+ # RESHAPE     MEAN         AVG POOL    MEAN         AVG POOL    MEAN
+ #    |                        |                        |
+ # 1x20x3x6                 1x6x6x10                 1x20x3x6
+ #   MEAN                   RESHAPE                    MEAN
+ #                             |
+ #                          1x20x3x6
+ #                            MEAN
+ ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1
+
# Check if operator ifm/ofm are sg ifm/ofm
ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
@@ -301,6 +352,20 @@ def fix_sg_input_output(op, arch, nng):
ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+ if bypass_need_to_keep_ofm_shape(op):
+ # Bypass needs to keep the OFM shape
+ if ifm_has_multiple_cons:
+ # Rare case:
+ # IFM needs to persist due to multiple consumers, so a copy op is needed
+ # OFM will replace IFM for the memory only op
+ insert_copy_op_before_op(op)
+ elif not (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+ # Only one consumer and the OFM is not a subgraph output or cpu consumed,
+ # so it is safe to replace ifm.shape with ofm.shape.
+ # The IFM can then replace the OFM for the memory only op and no copy op is needed
+ op.ifm.shape = op.ofm.shape
+
+ # Special case when OFM is sg_ofm or cpu_consumed
if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
# Both ifm and ofm need to persist, but only ifm needs a copy, in order to remove the memory only operator.
insert_copy_op_after_tens(op.ifm)
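Reduced to plain control flow, the branch added to fix_sg_input_output for the keep-OFM-shape case behaves roughly as sketched below. handle_keep_ofm_shape is a made-up name and only returns a description of the action taken; the real code also rewires tensors and records the optimisation:

```python
def handle_keep_ofm_shape(ifm_consumer_count, ofm_is_sg_ofm, ofm_is_cpu_consumed, ofm_shape):
    # Sketch of the decision only, not Vela's real fix_sg_input_output.
    if ifm_consumer_count > 1:
        # IFM must persist for the other consumers, so copy it through an
        # AvgPool NOP; the memory op is later bypassed by replacing IFM with OFM.
        return "insert_copy_op_before_op"
    if not (ofm_is_sg_ofm or ofm_is_cpu_consumed):
        # Single consumer and the OFM does not have to persist: give the IFM
        # the OFM's shape, so no copy op is needed and the memory op is
        # bypassed by replacing OFM with IFM.
        return f"set ifm.shape = {ofm_shape}"
    # Otherwise fall through to the existing sg_ofm / cpu_consumed handling.
    return "existing special-case handling"

# RESHAPE 1x6x6x10 -> 1x20x3x6 whose IFM is also read by a MEAN (the diagram above)
print(handle_keep_ofm_shape(2, False, False, [1, 20, 3, 6]))  # insert_copy_op_before_op
# Same RESHAPE with a single consumer and an OFM that stays on the NPU
print(handle_keep_ofm_shape(1, False, False, [1, 20, 3, 6]))  # set ifm.shape = [1, 20, 3, 6]
```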
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index c394778b..b8fe4b6a 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -319,7 +319,6 @@ class TFLiteSupportedOperators:
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
- self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_before_mean)
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -880,11 +879,3 @@ class TFLiteSupportedOperators:
extra = ", ".join(extra)
return valid, f"Op has non-const input(s): {extra}"
-
- @staticmethod
- def constraint_reshape_before_mean(op):
- "Reshape on NPU not supported before MEAN operator"
- for next_op in op.outputs[0].consumers():
- if next_op is not None and next_op.type == Op.Mean:
- return False, ""
- return True, ""