aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
author Raul Farkas <raul.farkas@arm.com> 2023-03-16 16:38:05 +0000
committer Raul Farkas <raul.farkas@arm.com> 2023-03-27 16:35:56 +0100
commit 72c6a2414205e033279f80b622cdf479c05a4f5b (patch)
tree 35dedce67cedd2fe5533cf0beb2942a7f31199e3 /ethosu/vela/tflite_graph_optimiser.py
parent 430002df36f79d035e31e8304fb8b176129cd3cc (diff)
download ethos-u-vela-72c6a2414205e033279f80b622cdf479c05a4f5b.tar.gz
MLBEDSW-6343: Remove op_index constraint
Remove op_index constraint and force linear format for all Conv2D that have strides that can be optimised.

Change-Id: Idef3508ab074ea9abeacac030eaaa15a00ad1211
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py 21
1 files changed, 12 insertions, 9 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 44f5d6ae..518b6db0 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -942,19 +942,21 @@ def reorder_depthwise_weights(op, arch, nng):
return op
-def fixup_strided_conv(op, arch, nng):
+def fixup_strided_conv(op: Operation, arch, nng):
+ """Optimize or fixup strided Conv2DBias
+ Optimization:
+ Reduce, when possible, the Conv2DBias stride from 2 to 1 by re-shaping
+ both IFM and filter.
+
+ Fixup:
+ Introduce software support for Conv2DBias with stride_width = 4 by
+ reducing it to 1 when possible by re-shaping both IFM and filter.
+ """
if op.type != Op.Conv2DBias:
return op
stride_x, stride_y = op.get_kernel_stride()
weight_tensor = op.weights
ifm_shape = op.ifm_shapes[0]
-
- # Do not optimize if op is not the first in the network and stride is
- # supported by the hardware
- if op.op_index != 0 and stride_x < 4:
- return op
- op.ifm.needs_linear_format = True
-
if (
(stride_x == 2 or stride_x == 4)
and ifm_shape.depth <= 4
@@ -1004,6 +1006,7 @@ def fixup_strided_conv(op, arch, nng):
stride_x = 1
op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
+ op.ifm.force_linear_format = True
return op
@@ -2125,7 +2128,6 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
convert_prelu,
convert_mul_max_to_abs_or_lrelu,
convert_lrelu,
- fixup_strided_conv,
convert_hardswish_to_lut,
rewrite_fully_connected_input,
convert_batched_fc_shape,
@@ -2139,6 +2141,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
convert_tanh_sigmoid_to_lut,
replace_pad_by_hw_pad,
fixup_dilation_gt2,
+ fixup_strided_conv,
]
for idx, sg in enumerate(nng.subgraphs):