aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-01-31 10:26:26 +0100
committerJohan Alfven <johan.alfven@arm.com>2023-02-06 10:09:58 +0100
commit12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f (patch)
treebbbe6eadb20249fb5974e8753d324b66814d8184 /ethosu/vela/tflite_model_semantic.py
parent4b7179936d659a6d4abd6b7659c2cc05c5a845fb (diff)
downloadethos-u-vela-12e481147de461e3ea63a8b1dcbc1b66b0fe8e6f.tar.gz
MLBEDSW-7284: MLCE: Fix assert for faulty Split op
- An assert in Vela is triggered when the number of splits does not evenly divide the input.shape[axis] value and the split offsets are calculated wrongly. - The fix is to add the same constraints as in the reference kernel and only run the Split op on the NPU when the criterias are fulfilled. - Modified test to reflect the new constraints - Updated SUPPORTED_OPS.md Change-Id: I4103ff4a3fdf9a813f5fcb7f51081b859e611100 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py29
1 files changed, 28 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index dc3b8185..2851ab16 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -155,6 +155,10 @@ class TFLiteSemantic:
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types)
self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_beta_value_range)
+ # Split specific checks:
+ self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_axis)
+ self.specific_constraints[Op.Split].append(TFLiteSemantic.constraint_split_num_splits)
+
# SplitV specific checks:
self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)
@@ -396,6 +400,29 @@ class TFLiteSemantic:
return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"
@staticmethod
+ def constraint_split_axis(op):
+ "Axis value must be in the range [-RANK(IFM) to +RANK(IFM))"
+ axis_tens = op.inputs[0]
+ input_tens = op.inputs[1]
+ dims = len(input_tens.shape)
+ axis = int(axis_tens.values)
+ axis += dims if axis < 0 else 0
+ valid = 0 <= axis < dims
+ return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}"
+
+ @staticmethod
+ def constraint_split_num_splits(op):
+ "Axis must be divisible by number of splits"
+ num_splits = op.attrs.get("num_splits")
+ axis_tens = op.inputs[0]
+ input_tens = op.inputs[1]
+ dims = len(input_tens.shape)
+ axis = int(axis_tens.values)
+ axis += dims if axis < 0 else 0
+ valid = input_tens.shape[axis] % num_splits == 0
+ return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}"
+
+ @staticmethod
def constraint_splitv_inferred(op):
"Only one size is allowed to be inferred"
sizes = op.inputs[1].values