diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 18 |
1 files changed, 7 insertions, 11 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 73137feb..af8695b7 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -289,18 +289,14 @@ def convert_resize_1x1_to_add(op): op.type = Op.Add # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor op.name = op.name + "_add" # Create an input tensor filled with zeros + name = op.inputs[1].name + "_add" + dtype = op.inputs[0].dtype shape = op.ofm_shapes[0].as_list() - tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add") - tens.values = np.zeros(shape, tens.dtype.as_numpy_type()) - tens.quantization = QuantizationParameters(0.0, 255.0) - tens.quantization.scale_f32 = 1.0 - tens.quantization.zero_point = 0 - tens.consumer_list = [op] - tens_op = op.inputs[1].ops[0] - tens_op.set_output_tensor(tens) - # Set the add inputs - op.inputs[1] = op.inputs[0] - op.inputs[0] = tens + values = np.zeros(shape, dtype.as_numpy_type()) + quantization = QuantizationParameters(0.0, 255.0) + quantization.scale_f32 = 1.0 + quantization.zero_point = 0 + op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 1) op.set_ifm_ofm_shapes() DebugDatabase.add_optimised(op, op) |