diff options
author | Johan Alfven <johan.alfven@arm.com> | 2023-11-13 10:23:32 +0100 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-11-13 15:08:12 +0100 |
commit | f418e832ffd5a10f549aa07a0c9c59406a374ffe (patch) | |
tree | b91d152f48c346bee172a0f554deb0570956c44f /ethosu/vela | |
parent | a8fda88bced0d11441467b6798885101d41ac8e9 (diff) | |
download | ethos-u-vela-f418e832ffd5a10f549aa07a0c9c59406a374ffe.tar.gz |
MLBEDSW-8317: Add semantic checks for Transpose
- Added semantic checks for Transpose
- Added unit tests for semantic checks
- Updated SUPPORTED_OPS.md
Change-Id: I3fcf13120f4b6811f8de27711996cdb9c19c9847
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/test/test_tflite_model_semantic.py | 26 | ||||
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 22 |
2 files changed, 48 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py index 7ca1bbda..bea004ae 100644 --- a/ethosu/vela/test/test_tflite_model_semantic.py +++ b/ethosu/vela/test/test_tflite_model_semantic.py @@ -647,3 +647,29 @@ def test_lstm_semantics(): op.inputs.pop() # Test restored valid configuration assert semantic_checker.is_operator_semantic_valid(op) + + +def test_transpose_semantics(): + # Test valid op + ifm = Tensor([2, 4], DataType.int8, "ifm") + perm = create_const_tensor("perm", [2], DataType.int32, [1, 0]) + ofm = Tensor([4, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation size + perm = create_const_tensor("perm", [3], DataType.int32, [1, 0]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [2, 0]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [0, -1]) + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) + # Test invalid permutation values + perm = create_const_tensor("perm", [2], DataType.int32, [1, 0]) + perm.values = None + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not semantic_checker.is_operator_semantic_valid(op) diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index d9ace1e8..eff40bc5 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -209,6 +209,10 @@ class TFLiteSemantic: # Exp specific checks self.specific_constraints[Op.Exp].append(TFLiteSemantic.constraint_input_signed) + # Transpose specific checks + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_size) + self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_values) + def is_operator_semantic_valid(self, op): ext_type = optype_to_builtintype(op.type) @@ -833,6 +837,24 @@ class TFLiteSemantic: extra = ", ".join(extra) return valid, f"Op has non-variable state tensor(s): {extra}" + @staticmethod + def constraint_transpose_permutation_size(op): + "Permutation array must be a 1D tensor with RANK(IFM) elements" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = len(perm.shape) == 1 and perm.shape[0] == dims + return valid, f"Op has ifm_dimension={dims} and permutation shape {perm.shape}" + + @staticmethod + def constraint_transpose_permutation_values(op): + "Permutation array must have constant values in the range [0, RANK(IFM))" + dims = len(op.inputs[0].shape) + perm = op.inputs[1] + valid = False + if perm.values is not None: + valid = not any([val < 0 or val >= dims for val in perm.values]) + return valid, f"Op has ifm_dimension={dims} and permutation values are: {perm.values}" + def tflite_semantic_checker(nng): semantic_checker = TFLiteSemantic() |