aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 24c0794a..444c04ad 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -436,7 +436,8 @@ class TFLiteSemantic:
axis_tens = op.inputs[0]
input_tens = op.inputs[1]
dims = len(input_tens.shape)
- axis = int(axis_tens.values)
+ # handle axis being a scalar or 1-D array
+ axis = int(axis_tens.values) if len(axis_tens.values.shape) == 0 else int(axis_tens.values[0])
axis += dims if axis < 0 else 0
valid = 0 <= axis < dims
return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}"
@@ -448,7 +449,8 @@ class TFLiteSemantic:
axis_tens = op.inputs[0]
input_tens = op.inputs[1]
dims = len(input_tens.shape)
- axis = int(axis_tens.values)
+ # handle axis being a scalar or 1-D array
+ axis = int(axis_tens.values) if len(axis_tens.values.shape) == 0 else int(axis_tens.values[0])
axis += dims if axis < 0 else 0
valid = input_tens.shape[axis] % num_splits == 0
return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}"