diff options
author | Rickard Bolin <rickard.bolin@arm.com> | 2021-12-07 09:09:14 +0000 |
---|---|---|
committer | Rickard Bolin <rickard.bolin@arm.com> | 2021-12-16 12:32:29 +0000 |
commit | 7d7cb671f369fd6ffdbddd12f1e29b6503df2c4d (patch) | |
tree | 195b6b46016d37e25d8a0bf36c55501c599445c4 /ethosu/vela/tflite_graph_optimiser.py | |
parent | e3dd2f3c4e09488776d45e8884123385b3e93e2a (diff) | |
download | ethos-u-vela-7d7cb671f369fd6ffdbddd12f1e29b6503df2c4d.tar.gz |
MLBEDSW-5554: Place MEAN op exceeding max height with axis==1 on CPU
Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I87dc5963972a7ef91db467b2ff8e0261e9899372
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index e01433d0..7b10f86a 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1301,8 +1301,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): if dims == 2: shape += [1] - # If height is greater than max kernel height, reshape to from HxW to 1x(HxW) - if h > 64: + # If height is greater than max kernel height, reshape from HxW to 1x(HxW) + if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool): shape = [shape[0], 1, h * w, shape[3]] op.ifm_shapes[0] = Shape4D(shape) if h > 256 and op.type == Op.AvgPool: |