diff options
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 1310ee63..028151fd 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1752,9 +1752,15 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): op.set_ifm_ofm_shapes() # If height is greater than max kernel height, reshape from HxW to 1x(HxW) + weight_shape = None if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool): + # This can only happen and be done for multiple axes, and + # h * w <= 256 for DepthwiseConv2DBias + # h * w <= 4096 for AvgPool + # which is checked in supported ops shape = [shape[0], 1, h * w, shape[3]] op.ifm_shapes[0] = Shape4D(shape) + weight_shape = [1, h * w, shape[3], shape[0]] if h > 256 and op.type == Op.AvgPool: op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w}) @@ -1769,8 +1775,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): weight_quant.scale_f32 = weight_scale weight_quant.zero_point = 0 - # Set weight shape to [H,W,C,B] - weight_shape = [h, w, shape[3], shape[0]] + if weight_shape is None: + # Set weight shape to [H,W,C,B] + weight_shape = [h, w, shape[3], shape[0]] # Add unit weight tensor op.set_input_tensor( |