aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDiqing Zhong <diqing.zhong@arm.com>2022-03-09 12:23:47 +0100
committerDiqing Zhong <diqing.zhong@arm.com>2022-03-11 17:08:15 +0100
commit1ddb2ed5fdc0a5e9944c5aeafcfc5ed4c07ea5cf (patch)
treea446c11f5fa8cf39f3037580c78fc5a32427e3e8
parentf8e353bd0a6e48f27e4c16b7243e403e5dae8d47 (diff)
downloadethos-u-vela-1ddb2ed5fdc0a5e9944c5aeafcfc5ed4c07ea5cf.tar.gz
Vela: Fix diff in mean op
- Extend ifm/ofm dimensions explicitly in mean op This fix a bug when ifm/ofm shape has different dimensions e.g. IFM=1x19x18x25 axis=2 OFM=1x19x25, the ofm_shape should be 1x19x1x25, not 1x1x19x25 - Fix wrong weight shape Change-Id: I269eb71ea56c09deee2aa6c6433d9b2baa98a113 Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py28
1 files changed, 23 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 97e30ad..3815eed 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1172,7 +1172,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
keep_dims = op.attrs.get("keep_dims", False)
inp, axis = op.inputs
shape = inp.shape
+ ofm_shape = op.ofm.shape
dims = len(shape)
+ dims_ofm = len(ofm_shape)
# Height and width axes have different index depending on dimensions
if axis.shape == [] or axis.shape[0] == 1: # single axis
@@ -1301,10 +1303,25 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
op.forced_input_quantization = fiq
# Change dimensions to 4
- if dims < 4:
- shape = [1] + shape
- if dims == 2:
- shape += [1]
+ def extend_dims(dim, in_shape):
+ if dim < 4:
+ in_shape = [1] + in_shape
+ if dim == 2:
+ in_shape += [1]
+ return in_shape
+
+ if dims < 4 or dims_ofm < 4:
+ # Fix the ofm dimension when keep_dims is false
+ # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
+ if isinstance(axis, int) and dims_ofm + 1 == dims:
+ ofm_shape.insert(axis, 1)
+ elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
+ for i in axis:
+ ofm_shape.insert(i, 1)
+ shape = extend_dims(dims, shape)
+ dims_ofm = len(ofm_shape)
+ ofm_shape = extend_dims(dims_ofm, ofm_shape)
+ op.set_ifm_ofm_shapes()
# If height is greater than max kernel height, reshape from HxW to 1x(HxW)
if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
@@ -1325,7 +1342,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
weight_quant.zero_point = 0
# Set weight shape to [H,W,C,B]
- weight_shape = shape[1:4] + [shape[0]]
+ weight_shape = [h, w, shape[3], shape[0]]
+
# Add unit weight tensor
op.set_input_tensor(
create_const_tensor(