Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 64
1 file changed, 63 insertions(+), 1 deletion(-)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 393a8323..6297fca2 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2080,6 +2080,68 @@ def fixup_dilation_gt2(op, arch, nng):
     return op
 
 
+def fixup_reshape(op, arch, nng):
+    def _get_explicit_shape(implicit_shape, total_size):
+        # the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to
+        # the appropriate value
+        if implicit_shape is None:
+            return None
+
+        explicit_shape = list(implicit_shape)
+        if -1 in explicit_shape:
+            explicit_shape[explicit_shape.index(-1)] = int(total_size / abs(np.prod(implicit_shape)))
+
+        return explicit_shape
+
+    if op.type == Op.Reshape:
+        ifm_tensor, _, ofm_tensor = op.get_ifm_ifm2_ofm()
+        ifm_size = ifm_tensor.elements()
+        ofm_shape = ofm_tensor.shape
+
+        new_shape_tensor_shape = op.inputs[1].values.flatten() if len(op.inputs) > 1 else None
+        new_shape_tensor_shape = _get_explicit_shape(new_shape_tensor_shape, ifm_size)
+
+        new_shape_attribute = op.attrs.get("new_shape", None)
+        new_shape_attribute = _get_explicit_shape(new_shape_attribute, ifm_size)
+
+        # if present the new shape tensor overrides the new_shape attribute
+        if new_shape_tensor_shape is not None:
+            # check tensor
+            if not np.array_equal(new_shape_tensor_shape, ofm_shape):
+                print(
+                    f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new shape tensor"
+                    f" ({new_shape_tensor_shape}) that does not match output tensor shape {ofm_shape}. Will use output"
+                    f" tensor shape."
+                )
+        elif new_shape_attribute is not None:
+            # check attribute
+            if not np.array_equal(new_shape_attribute, ofm_shape):
+                print(
+                    f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new_shape attribute"
+                    f" ({new_shape_attribute}) that does not match output tensor shape {ofm_shape}. Will use output"
+                    f" tensor shape."
+                )
+        else:
+            print(
+                f"Warning: {optype_to_builtintype(op.type)} '{op.name}' does not have a new shape tensor or a new_shape"
+                f" attribute. Will use output tensor shape {ofm_shape}."
+            )
+
+        # force new shape tensor to output shape
+        new_shape_tensor = create_const_tensor(
+            op.name + "_new_shape", [len(ofm_shape)], DataType.int32, np.array(ofm_shape, np.int32)
+        )
+        if len(op.inputs) > 1:
+            op.set_input_tensor(new_shape_tensor, 1)
+        else:
+            op.add_input_tensor(new_shape_tensor)
+
+        # force new_shape attribute to output shape
+        op.attrs["new_shape"] = ofm_shape
+
+    return op
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
     return op
@@ -2104,7 +2166,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
     )
 
     # Pre-processing step
-    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes]
+    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape]
 
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
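
For reference, a minimal standalone sketch of the shape-inference rule that the patch's _get_explicit_shape helper implements (the top-level function name and the example call below are illustrative, not part of the patch): TFLite's RESHAPE allows one dimension of the new shape to be given as -1, meaning "whatever remains", and the explicit value is the total element count divided by the product of the known dimensions.

import numpy as np

def get_explicit_shape(implicit_shape, total_size):
    # Replace the single -1 marker with the inferred dimension so that
    # the product of the explicit shape equals total_size.
    if implicit_shape is None:
        return None
    explicit_shape = list(implicit_shape)
    if -1 in explicit_shape:
        explicit_shape[explicit_shape.index(-1)] = int(total_size / abs(np.prod(implicit_shape)))
    return explicit_shape

# Reshaping a 24-element tensor to (2, -1, 4) infers the middle
# dimension as 24 / (2 * 4) = 3.
assert get_explicit_shape([2, -1, 4], 24) == [2, 3, 4]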