From c7c0b1ba5e7c3dea73d1ab175b03ff188658d27b Mon Sep 17 00:00:00 2001
From: Diqing Zhong
Date: Mon, 26 Oct 2020 11:45:25 +0100
Subject: MLBEDSW-3283: Bug fix: StridedSlice Op is placed on CPU

Signed-off-by: Diqing Zhong
Change-Id: I91a3b277cda91dca3bad38908d4ed11a4f5d7d5f
---
 ethosu/vela/graph_optimiser.py | 109 ++++++++++++++++++++++++-----------------
 1 file changed, 64 insertions(+), 45 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 4696446..32f97d2 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -420,54 +420,73 @@ def unfuse_activation_function(op, arch, nng):
     return op
 
 
+def fixup_stridedslice_output(tens, arch, nng):
+    op = tens.ops[0]
+    if op.type == Op.StridedSlice:
+        reshape_input_shape = tens.shape
+        new_axis_mask = op.attrs["new_axis_mask"]
+        shrink_axis_mask = op.attrs["shrink_axis_mask"]
+        ellipsis_mask = op.attrs["ellipsis_mask"]
+
+        if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0:
+            # Not supported, will be put on CPU
+            return tens
+        if shrink_axis_mask == 0 and new_axis_mask == 0:
+            # Equal Rank StridedSlice, no need to insert reshape
+            return tens
+        elif shrink_axis_mask != 0:
+            n = 0
+            axis = 0
+            while shrink_axis_mask:
+                prev_mask = shrink_axis_mask
+                n += 1
+                shrink_axis_mask &= shrink_axis_mask - 1
+                axis = int(math.log2(prev_mask - shrink_axis_mask))
+                reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:]
+
+            assert len(tens.shape) == (len(op.inputs[0].shape) - n)
+            op.attrs["shrink_axis_mask"] = 0
+
+        elif new_axis_mask != 0:
+            n = 0
+            axis = 0
+            while new_axis_mask:
+                prev_mask = new_axis_mask
+                n += 1
+                new_axis_mask &= new_axis_mask - 1
+                axis = int(math.log2(prev_mask - new_axis_mask))
+                reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
+                new_axis_mask >>= 1
+
+            assert len(tens.shape) == (len(op.inputs[0].shape) + n)
+            op.attrs["new_axis_mask"] = 0
+
+        # Construct 1 shape tensor to be used by all inserted reshape ops
+        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
+
+        for idx, out_tens in enumerate(op.outputs):
+            reshape_in = out_tens.clone("_reshaped")
+            reshape_in.set_all_shapes(reshape_input_shape)
+            reshape_in.ops = [op]
+
+            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
+            reshape_op.attrs["new_shape"] = reshape_input_shape
+            reshape_op.inputs = [reshape_in, new_shape_tens]
+            reshape_op.set_output_tensor(out_tens)
+
+            op.outputs[idx] = reshape_in
+
+    return tens
+
+
 def fixup_unpack_output(tens, arch, nng):
     op = tens.ops[0]
-    if op.run_on_npu and op.type in set((Op.Unpack, Op.StridedSlice)):
+    if op.run_on_npu and op.type == Op.Unpack:
         # Unpack is also referred to as Unstack
         # Requires the rewrite_split function to be called on the op afterwards
-
-        reshape_input_shape = tens.shape
-        if op.type == Op.StridedSlice:
-            new_axis_mask = op.attrs["new_axis_mask"]
-            shrink_axis_mask = op.attrs["shrink_axis_mask"]
-            ellipsis_mask = op.attrs["ellipsis_mask"]
-
-            if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0:
-                # Not supported, will be put on CPU
-                return tens
-            if shrink_axis_mask == 0 and new_axis_mask == 0:
-                # Equal Rank StridedSlice, no need to insert reshape
-                return tens
-            elif shrink_axis_mask != 0:
-                n = 0
-                axis = 0
-                while shrink_axis_mask:
-                    prev_mask = shrink_axis_mask
-                    n += 1
-                    shrink_axis_mask &= shrink_axis_mask - 1
-                    axis = int(math.log2(prev_mask - shrink_axis_mask))
-                    reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:]
-
-                assert len(tens.shape) == (len(op.inputs[0].shape) - n)
-                op.attrs["shrink_axis_mask"] = 0
-
-            elif new_axis_mask != 0:
-                n = 0
-                axis = 0
-                while new_axis_mask:
-                    prev_mask = new_axis_mask
-                    n += 1
-                    new_axis_mask &= new_axis_mask - 1
-                    axis = int(math.log2(prev_mask - new_axis_mask))
-                    reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
-                    new_axis_mask >>= 1
-
-                assert len(tens.shape) == (len(op.inputs[0].shape) + n)
-                op.attrs["new_axis_mask"] = 0
-        else:
-            axis = int(op.attrs["axis"])
-            op.type = Op.UnpackReshaped
-            reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
+        axis = int(op.attrs["axis"])
+        op.type = Op.UnpackReshaped
+        reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
 
         # Construct 1 shape tensor to be used by all inserted reshape ops
         new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
@@ -1073,7 +1092,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
     for idx, sg in enumerate(nng.subgraphs):
         # rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
+            nng, sg, arch, [fixup_stridedslice_output], op_rewrite_list, rewrite_unsupported=False,
         )
 
     for idx, sg in enumerate(nng.subgraphs):
-- 
cgit v1.2.1
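
[Editor's note] The `while mask: mask &= mask - 1` loops in the patch walk
the set bits of shrink_axis_mask/new_axis_mask from least to most
significant; taking log2 of the bit that was just cleared recovers its
position, i.e. the tensor axis that mask bit refers to. Below is a minimal
standalone sketch of the shrink-axis case; the shapes, the mask value and
the helper name axes_from_mask are illustrative only and not part of vela.

    import math

    def axes_from_mask(mask):
        # Yield the positions of the set bits in `mask`, lowest first.
        # `mask &= mask - 1` clears the lowest set bit; log2 of the
        # cleared bit (prev - mask) gives its position.
        while mask:
            prev = mask
            mask &= mask - 1
            yield int(math.log2(prev - mask))

    # Hypothetical case: a rank-4 StridedSlice input with axes 0 and 2
    # shrunk away leaves a rank-2 output tensor.
    tens_shape = [10, 20]
    shrink_axis_mask = 0b0101  # bits 0 and 2 set

    reshape_input_shape = list(tens_shape)
    for axis in axes_from_mask(shrink_axis_mask):
        # Re-insert a 1 at each shrunk axis, as the patched loop does, so
        # the StridedSlice can emit an equal-rank output on the NPU.
        reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:]

    print(reshape_input_shape)  # [1, 10, 1, 20]

The Reshape op the patch inserts then maps [1, 10, 1, 20] back to the
[10, 20] shape that consumers expect, which is why shrink_axis_mask is
cleared on the op afterwards.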
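
[Editor's note] The per-output loop in fixup_stridedslice_output follows a
common graph-surgery pattern: clone the op's output tensor, give the clone
the equal-rank shape, make the clone the op's direct output, and bridge
clone to original tensor with a new Reshape so downstream consumers stay
wired to the tensor they already reference. A sketch of that pattern with
stand-in classes (these are not vela's real Tensor/Operation APIs):

    from dataclasses import dataclass, field

    @dataclass
    class Tensor:
        name: str
        shape: list
        ops: list = field(default_factory=list)  # producing operations

        def clone(self, suffix):
            return Tensor(self.name + suffix, list(self.shape))

    @dataclass
    class Operation:
        optype: str
        name: str
        inputs: list = field(default_factory=list)
        outputs: list = field(default_factory=list)

        def set_output_tensor(self, tens):
            tens.ops = [self]
            self.outputs = [tens]

    def insert_reshape_after(op, idx, equal_rank_shape):
        out_tens = op.outputs[idx]                # consumers already hold this
        reshape_in = out_tens.clone("_reshaped")  # equal-rank intermediate
        reshape_in.shape = equal_rank_shape
        reshape_in.ops = [op]                     # produced by the original op

        reshape_op = Operation("Reshape", "{}{}_reshape".format(op.name, idx))
        reshape_op.inputs = [reshape_in]
        reshape_op.set_output_tensor(out_tens)    # consumers stay untouched
        op.outputs[idx] = reshape_in              # op now feeds the Reshape

Because out_tens keeps its identity, every consumer of the StridedSlice
output transparently reads the Reshape's result after the rewrite.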