aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py64
1 files changed, 63 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 393a8323..6297fca2 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2080,6 +2080,68 @@ def fixup_dilation_gt2(op, arch, nng):
return op
+def fixup_reshape(op, arch, nng):
+ def _get_explicit_shape(implicit_shape, total_size):
+ # the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to
+ # the appropriate value
+ if implicit_shape is None:
+ return None
+
+ explicit_shape = list(implicit_shape)
+ if -1 in explicit_shape:
+ explicit_shape[explicit_shape.index(-1)] = int(total_size / abs(np.prod(implicit_shape)))
+
+ return explicit_shape
+
+ if op.type == Op.Reshape:
+ ifm_tensor, _, ofm_tensor = op.get_ifm_ifm2_ofm()
+ ifm_size = ifm_tensor.elements()
+ ofm_shape = ofm_tensor.shape
+
+ new_shape_tensor_shape = op.inputs[1].values.flatten() if len(op.inputs) > 1 else None
+ new_shape_tensor_shape = _get_explicit_shape(new_shape_tensor_shape, ifm_size)
+
+ new_shape_attribute = op.attrs.get("new_shape", None)
+ new_shape_attribute = _get_explicit_shape(new_shape_attribute, ifm_size)
+
+ # if present the new shape tensor overrides the new_shape attribute
+ if new_shape_tensor_shape is not None:
+ # check tensor
+ if not np.array_equal(new_shape_tensor_shape, ofm_shape):
+ print(
+ f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new shape tensor"
+ f" ({new_shape_tensor_shape}) that does not match output tensor shape {ofm_shape}. Will use output"
+ f" tensor shape."
+ )
+ elif new_shape_attribute is not None:
+ # check attribute
+ if not np.array_equal(new_shape_attribute, ofm_shape):
+ print(
+ f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new_shape attribute"
+ f" ({new_shape_attribute}) that does not match output tensor shape {ofm_shape}. Will use output"
+ f" tensor shape."
+ )
+ else:
+ print(
+ f"Warning: {optype_to_builtintype(op.type)} '{op.name}' does not have a new shape tensor or a new_shape"
+ f" attribute. Will use output tensor shape {ofm_shape}."
+ )
+
+ # force new shape tensor to output shape
+ new_shape_tensor = create_const_tensor(
+ op.name + "_new_shape", [len(ofm_shape)], DataType.int32, np.array(ofm_shape, np.int32)
+ )
+ if len(op.inputs) > 1:
+ op.set_input_tensor(new_shape_tensor, 1)
+ else:
+ op.add_input_tensor(new_shape_tensor)
+
+ # force new_shape attribute to output shape
+ op.attrs["new_shape"] = ofm_shape
+
+ return op
+
+
def supported_operator_check(op, arch, nng):
op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
return op
@@ -2104,7 +2166,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
)
# Pre-processing step
- pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes]
+ pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape]
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(