aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwilisa01 <william.isaksson@arm.com>2023-01-12 08:17:23 +0000
committerWilliam Isaksson <william.isaksson@arm.com>2023-02-10 11:09:37 +0000
commit8289d519b0e96b3178cfa8d7be4442f3d963ed0a (patch)
tree49ee76bcb8d008e93a43011e0f7cc4ed52251ced
parent428a8d54f574a73804274e53e61f711aebc25a0a (diff)
downloadethos-u-vela-8289d519b0e96b3178cfa8d7be4442f3d963ed0a.tar.gz
MLBEDSW-4960: convert_resizebilinear_1x1_to_add creates constant output tensor
Sets second input tensor of resize op to be constant and refactored function. Signed-off-by: wilisa01 <william.isaksson@arm.com> Change-Id: I496764f18b4c1ae0fa1a828dd7a90e937a42d41b
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py18
1 files changed, 7 insertions, 11 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 73137feb..af8695b7 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -289,18 +289,14 @@ def convert_resize_1x1_to_add(op):
op.type = Op.Add # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
op.name = op.name + "_add"
# Create an input tensor filled with zeros
+ name = op.inputs[1].name + "_add"
+ dtype = op.inputs[0].dtype
shape = op.ofm_shapes[0].as_list()
- tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
- tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
- tens.quantization = QuantizationParameters(0.0, 255.0)
- tens.quantization.scale_f32 = 1.0
- tens.quantization.zero_point = 0
- tens.consumer_list = [op]
- tens_op = op.inputs[1].ops[0]
- tens_op.set_output_tensor(tens)
- # Set the add inputs
- op.inputs[1] = op.inputs[0]
- op.inputs[0] = tens
+ values = np.zeros(shape, dtype.as_numpy_type())
+ quantization = QuantizationParameters(0.0, 255.0)
+ quantization.scale_f32 = 1.0
+ quantization.zero_point = 0
+ op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 1)
op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, op)