diff options
author | Diqing Zhong <diqing.zhong@arm.com> | 2022-03-09 12:23:47 +0100 |
---|---|---|
committer | Diqing Zhong <diqing.zhong@arm.com> | 2022-03-11 17:08:15 +0100 |
commit | 1ddb2ed5fdc0a5e9944c5aeafcfc5ed4c07ea5cf (patch) | |
tree | a446c11f5fa8cf39f3037580c78fc5a32427e3e8 | |
parent | f8e353bd0a6e48f27e4c16b7243e403e5dae8d47 (diff) | |
download | ethos-u-vela-1ddb2ed5fdc0a5e9944c5aeafcfc5ed4c07ea5cf.tar.gz |
Vela: Fix diff in mean op
- Extend ifm/ofm dimensions explicitly in mean op
This fix a bug when ifm/ofm shape has different dimensions
e.g. IFM=1x19x18x25 axis=2 OFM=1x19x25,
the ofm_shape should be 1x19x1x25, not 1x1x19x25
- Fix wrong weight shape
Change-Id: I269eb71ea56c09deee2aa6c6433d9b2baa98a113
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 97e30ad6..3815eedd 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1172,7 +1172,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): keep_dims = op.attrs.get("keep_dims", False) inp, axis = op.inputs shape = inp.shape + ofm_shape = op.ofm.shape dims = len(shape) + dims_ofm = len(ofm_shape) # Height and width axes have different index depending on dimensions if axis.shape == [] or axis.shape[0] == 1: # single axis @@ -1301,10 +1303,25 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): op.forced_input_quantization = fiq # Change dimensions to 4 - if dims < 4: - shape = [1] + shape - if dims == 2: - shape += [1] + def extend_dims(dim, in_shape): + if dim < 4: + in_shape = [1] + in_shape + if dim == 2: + in_shape += [1] + return in_shape + + if dims < 4 or dims_ofm < 4: + # Fix the ofm dimension when keep_dims is false + # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC + if isinstance(axis, int) and dims_ofm + 1 == dims: + ofm_shape.insert(axis, 1) + elif isinstance(axis, list) and (dims_ofm + len(axis) == dims): + for i in axis: + ofm_shape.insert(i, 1) + shape = extend_dims(dims, shape) + dims_ofm = len(ofm_shape) + ofm_shape = extend_dims(dims_ofm, ofm_shape) + op.set_ifm_ofm_shapes() # If height is greater than max kernel height, reshape from HxW to 1x(HxW) if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool): @@ -1325,7 +1342,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): weight_quant.zero_point = 0 # Set weight shape to [H,W,C,B] - weight_shape = shape[1:4] + [shape[0]] + weight_shape = [h, w, shape[3], shape[0]] + # Add unit weight tensor op.set_input_tensor( create_const_tensor( |