diff options
author | Tim Hall <tim.hall@arm.com> | 2023-07-06 11:42:02 +0100 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2023-07-06 12:49:00 +0100 |
commit | 762d3acb93c597b5299864f9b3ee551107978ee2 (patch) | |
tree | 94eee7b497973c31c73514407084fefaf51515bd /ethosu | |
parent | 2d54e5cd9f86dcabd9f419f196245af4479abcd5 (diff) | |
download | ethos-u-vela-762d3acb93c597b5299864f9b3ee551107978ee2.tar.gz |
MLBEDSW-7832: test_tflite_model_semantic converting array to scalar
- The problem is that the axis value can be either a scalar or an
array containing a single element
- The solution is to check the length of the shape because the size
attribute returns the same value for both cases
- This did not show up before because pytest warnings were not being
treated as errors
- Removed pre-commit pytest option that caused tests to be searched for
from the root directory
- Updated pyproject.toml pytest options to explicitly specify the test
directories, and to treat warnings as errors
Change-Id: I037054768e5c34f253b6062eadba1c3419ff65e4
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 24c0794a..444c04ad 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -436,7 +436,8 @@ class TFLiteSemantic: axis_tens = op.inputs[0] input_tens = op.inputs[1] dims = len(input_tens.shape) - axis = int(axis_tens.values) + # handle axis being a scalar or 1-D array + axis = int(axis_tens.values) if len(axis_tens.values.shape) == 0 else int(axis_tens.values[0]) axis += dims if axis < 0 else 0 valid = 0 <= axis < dims return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}" @@ -448,7 +449,8 @@ class TFLiteSemantic: axis_tens = op.inputs[0] input_tens = op.inputs[1] dims = len(input_tens.shape) - axis = int(axis_tens.values) + # handle axis being a scalar or 1-D array + axis = int(axis_tens.values) if len(axis_tens.values.shape) == 0 else int(axis_tens.values[0]) axis += dims if axis < 0 else 0 valid = input_tens.shape[axis] % num_splits == 0 return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}" |