From 75d3402204145731c2ebe0131ee47d966fd95562 Mon Sep 17 00:00:00 2001 From: William Isaksson Date: Thu, 10 Aug 2023 12:22:44 +0000 Subject: MLBEDSW-7832: test_tflite_model_semantic converting array to scalar - now only converts array directly if ndim==0 Signed-off-by: William Isaksson Change-Id: Id23e419bc7dd717f9694013180d4609819fd2f56 --- ethosu/vela/tflite_model_semantic.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index ea7ef4a3..d2e0ba5a 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -472,7 +472,10 @@ class TFLiteSemantic: input_tens = op.inputs[1] dims = len(input_tens.shape) # handle axis being a scalar or 1-D array - axis = int(axis_tens.values) if len(axis_tens.values.shape) == 0 else int(axis_tens.values[0]) + if axis_tens.values.ndim == 0: + axis = int(axis_tens.values) + else: + axis = int(axis_tens.values[0]) axis += dims if axis < 0 else 0 valid = 0 <= axis < dims return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}" @@ -485,7 +488,10 @@ class TFLiteSemantic: input_tens = op.inputs[1] dims = len(input_tens.shape) # handle axis being a scalar or 1-D array - axis = int(axis_tens.values) if len(axis_tens.values.shape) == 0 else int(axis_tens.values[0]) + if axis_tens.values.ndim == 0: + axis = int(axis_tens.values) + else: + axis = int(axis_tens.values[0]) axis += dims if axis < 0 else 0 valid = input_tens.shape[axis] % num_splits == 0 return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}" -- cgit v1.2.1