aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 5bf2c459..dfa27199 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -1040,11 +1040,11 @@ class SupportedOperators:
def constraint_mean_axis(op):
"Axis indices must correspond to height and width axes"
dims = len(op.inputs[0].shape)
- axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
if dims == 2 or dims == 3:
- valid = axis in (0, 1, [0, 1], [1, 0])
+ valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
elif dims == 4:
- valid = axis in (1, 2, [1, 2], [2, 1])
+ valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
return valid, f"Axis is {axis}"
@classmethod
@@ -1082,7 +1082,7 @@ class SupportedOperators:
keep_dims is set to True and
IFM datatype is int8"""
shape = op.ifm.shape
- axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
# doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
# and constraint_mean_height_width_product
if (