aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py59
1 files changed, 57 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 76383a4b..78906374 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -949,7 +949,57 @@ def reorder_depthwise_weights(op, arch, nng):
return op
-def fixup_strided_conv(op: Operation, arch, nng) -> Operation:
+def convert_avg_pool_to_conv2d(op: Operation, arch, nng) -> Operation:
+ """Convert strided Average Pools with stride >= 4 to Conv2D."""
+ if op.type != Op.AvgPool:
+ return op
+
+ stride_x, stride_y = op.get_kernel_stride()
+ # For strides <= 3 no optimization is needed
+ if stride_x <= 3:
+ return op
+ h, w = op.attrs["filter_height"], op.attrs["filter_width"]
+ inputs = op.inputs[0]
+ shape = inputs.shape
+
+ # Set necessary conv2d attributes
+ op.attrs.update(
+ {
+ "stride_h": stride_y,
+ "stride_w": stride_x,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "strides": (1, stride_y, stride_x, 1),
+ "dilation": (1, 1, 1, 1),
+ }
+ )
+
+ # Change op type
+ op.type = Op.Conv2DBias
+ op.name += "_conv2d"
+
+ op.rounding_mode = RoundingMode.AwayZero
+ shape = [h, w, 1, op.ofm.shape[-1]]
+ weights = np.full(shape, 1)
+ quant = QuantizationParameters(scale_f32=1 / (h * w), zero_point=0)
+ # Add unit weight tensor
+ op.add_input_tensor(
+ create_const_tensor(
+ "weights",
+ shape,
+ inputs.dtype,
+ weights,
+ quantization=quant,
+ ),
+ )
+ op.weights.values = np.reshape(op.inputs[1].values, shape)
+
+ # Set IFM/OFM shapes after changing op type
+ op.set_ifm_ofm_shapes()
+ return op
+
+
+def fixup_strided_conv(op: Operation, arch, nng):
"""Optimize or fixup strided Conv2DBias
Optimization:
Reduce, when possible, the Conv2DBias stride from N with 1 > N > 4 to 1
@@ -1853,7 +1903,11 @@ def fixup_bias_tensors(op, arch, nng, dtype=None):
if dtype is None:
dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
- op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
+ bias_index = op.type.info.indices.biases[0]
+ if bias_index < len(op.inputs):
+ op.set_input_tensor(bias_tensor, bias_index)
+ else:
+ op.add_input_tensor(bias_tensor)
return op
@@ -2349,6 +2403,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
convert_prelu,
convert_mul_max_to_abs_or_lrelu,
convert_lrelu,
+ convert_avg_pool_to_conv2d,
fixup_strided_conv,
convert_hardswish_to_lut,
rewrite_fully_connected_input,