From 73320a48dfa711f5938b0e3d8e03b9858558b899 Mon Sep 17 00:00:00 2001 From: Dwight Lidman Date: Thu, 5 Nov 2020 10:34:41 +0100 Subject: MLBEDSW-3377: fixup_stridedslice_output may silently change CPU ops This commit removes the constraint on all tensor shapes matching the OFM shape. The motivation is that this constraint essentially only checks that the fixup function has run. This means that it removes the possibility for the fixup function to run after the supported operator check and this effectively means that any StridedSlice operator that would be placed on the CPU is still modified by the fixup function. Because the fixup function is moved to after the supported operators check, some unreachable cases are removed from the fixup function. Signed-off-by: Dwight Lidman Change-Id: I7a82126b7de73bd67873b4e6daf53a6767e33d16 --- ethosu/vela/graph_optimiser.py | 16 ++++------------ ethosu/vela/supported_operators.py | 17 ----------------- ethosu/vela/test/test_supported_operators.py | 6 ------ 3 files changed, 4 insertions(+), 35 deletions(-) diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 32f97d2f..e31348b5 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -422,19 +422,12 @@ def unfuse_activation_function(op, arch, nng): def fixup_stridedslice_output(tens, arch, nng): op = tens.ops[0] - if op.type == Op.StridedSlice: + if op.run_on_npu and op.type == Op.StridedSlice: reshape_input_shape = tens.shape new_axis_mask = op.attrs["new_axis_mask"] shrink_axis_mask = op.attrs["shrink_axis_mask"] - ellipsis_mask = op.attrs["ellipsis_mask"] - if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0: - # Not supported, will be put on CPU - return tens - if shrink_axis_mask == 0 and new_axis_mask == 0: - # Equal Rank StridedSlice, no need to insert reshape - return tens - elif shrink_axis_mask != 0: + if shrink_axis_mask != 0: n = 0 axis = 0 while shrink_axis_mask: @@ -446,7 +439,6 @@ def fixup_stridedslice_output(tens, arch, nng): assert len(tens.shape) == (len(op.inputs[0].shape) - n) op.attrs["shrink_axis_mask"] = 0 - elif new_axis_mask != 0: n = 0 axis = 0 @@ -1092,7 +1084,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_stridedslice_output], op_rewrite_list, rewrite_unsupported=False, + nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False, ) for idx, sg in enumerate(nng.subgraphs): @@ -1113,7 +1105,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # combined rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_unpack_output, rewrite_concat, rewrite_split], [] + nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [] ) if verbose_graph: diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 04cda1da..3e649e09 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -219,7 +219,6 @@ class SupportedOperators: # StridedSlice specific checks: self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_input_count) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_inputs_const) - self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_tens_size_matches) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_stride_values) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_ellipsis_mask) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_axis_masks) @@ -727,22 +726,6 @@ class SupportedOperators: extra = ", ".join(extra) return valid, f"Op has non-constant tensors: {extra}" - @staticmethod - def constraint_stridedslice_tens_size_matches(op): - "All Input sizes must match OFM size" - ifm, begin, end, strides = op.inputs - ifm_size = len(ifm.shape) - ofm_size = len(op.ofm.shape) - begin_size = len(begin.values) - end_size = len(end.values) - strides_size = len(strides.values) - valid = ifm_size == ofm_size == begin_size == end_size == strides_size - extra = ( - f"Op has ofm_size={ofm_size}, ifm_size={ifm_size}," - f" begin_size={begin_size}, end_size={end_size} and strides_size={strides_size}" - ) - return valid, extra - @staticmethod def constraint_stridedslice_stride_values(op): "All Strides values must be 1" diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 595ea590..245ebcf9 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -486,12 +486,6 @@ def test_constraint_stridedslice_inputs_const(): assert not support.is_operator_supported(op) -def test_constraint_stridedslice_tens_size_matches(): - op = create_strided_slice() - op.inputs[1].values = [1, 1, 1, 1, 1, 1, 1, 1] - assert not support.is_operator_supported(op) - - def test_constraint_stridedslice_stride_values(): # Unsupported strides op = create_strided_slice() -- cgit v1.2.1