diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 5bf2c459..dfa27199 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -1040,11 +1040,11 @@ class SupportedOperators: def constraint_mean_axis(op): "Axis indices must correspond to height and width axes" dims = len(op.inputs[0].shape) - axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values) if dims == 2 or dims == 3: - valid = axis in (0, 1, [0, 1], [1, 0]) + valid = axis in (0, 1, [0], [1], [0, 1], [1, 0]) elif dims == 4: - valid = axis in (1, 2, [1, 2], [2, 1]) + valid = axis in (1, 2, [1], [2], [1, 2], [2, 1]) return valid, f"Axis is {axis}" @classmethod @@ -1082,7 +1082,7 @@ class SupportedOperators: keep_dims is set to True and IFM datatype is int8""" shape = op.ifm.shape - axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values) + axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values) # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool # and constraint_mean_height_width_product if ( |