aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--ethosu/vela/graph_optimiser.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 642f1349..7c60368d 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1472,8 +1472,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
dims = len(shape)
# Height and width axes have different index depending on dimensions
- if axis.shape == []: # single axis
- axis = int(axis.values)
+ if len(axis.shape) <= 1: # single axis
+ axis = int(axis.values) if len(axis.shape) == 0 else axis.values[0]
if dims in (2, 3):
if axis == 0:
h, w = shape[axis], 1