From 1ddb2ed5fdc0a5e9944c5aeafcfc5ed4c07ea5cf Mon Sep 17 00:00:00 2001 From: Diqing Zhong Date: Wed, 9 Mar 2022 12:23:47 +0100 Subject: Vela: Fix diff in mean op - Extend ifm/ofm dimensions explicitly in mean op This fix a bug when ifm/ofm shape has different dimensions e.g. IFM=1x19x18x25 axis=2 OFM=1x19x25, the ofm_shape should be 1x19x1x25, not 1x1x19x25 - Fix wrong weight shape Change-Id: I269eb71ea56c09deee2aa6c6433d9b2baa98a113 Signed-off-by: Diqing Zhong --- ethosu/vela/tflite_graph_optimiser.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 97e30ad6..3815eedd 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1172,7 +1172,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): keep_dims = op.attrs.get("keep_dims", False) inp, axis = op.inputs shape = inp.shape + ofm_shape = op.ofm.shape dims = len(shape) + dims_ofm = len(ofm_shape) # Height and width axes have different index depending on dimensions if axis.shape == [] or axis.shape[0] == 1: # single axis @@ -1301,10 +1303,25 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): op.forced_input_quantization = fiq # Change dimensions to 4 - if dims < 4: - shape = [1] + shape - if dims == 2: - shape += [1] + def extend_dims(dim, in_shape): + if dim < 4: + in_shape = [1] + in_shape + if dim == 2: + in_shape += [1] + return in_shape + + if dims < 4 or dims_ofm < 4: + # Fix the ofm dimension when keep_dims is false + # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC + if isinstance(axis, int) and dims_ofm + 1 == dims: + ofm_shape.insert(axis, 1) + elif isinstance(axis, list) and (dims_ofm + len(axis) == dims): + for i in axis: + ofm_shape.insert(i, 1) + shape = extend_dims(dims, shape) + dims_ofm = len(ofm_shape) + ofm_shape = extend_dims(dims_ofm, ofm_shape) + op.set_ifm_ofm_shapes() # If height is greater than max kernel height, reshape from HxW to 1x(HxW) if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool): @@ -1325,7 +1342,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): weight_quant.zero_point = 0 # Set weight shape to [H,W,C,B] - weight_shape = shape[1:4] + [shape[0]] + weight_shape = [h, w, shape[3], shape[0]] + # Add unit weight tensor op.set_input_tensor( create_const_tensor( -- cgit v1.2.1