aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py26
-rw-r--r--ethosu/vela/tflite_model_semantic.py22
2 files changed, 48 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 7ca1bbda..bea004ae 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -647,3 +647,29 @@ def test_lstm_semantics():
op.inputs.pop()
# Test restored valid configuration
assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_transpose_semantics():
+ # Test valid op
+ ifm = Tensor([2, 4], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [2], DataType.int32, [1, 0])
+ ofm = Tensor([4, 2], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation size
+ perm = create_const_tensor("perm", [3], DataType.int32, [1, 0])
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation values
+ perm = create_const_tensor("perm", [2], DataType.int32, [2, 0])
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation values
+ perm = create_const_tensor("perm", [2], DataType.int32, [0, -1])
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation values
+ perm = create_const_tensor("perm", [2], DataType.int32, [1, 0])
+ perm.values = None
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index d9ace1e8..eff40bc5 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -209,6 +209,10 @@ class TFLiteSemantic:
# Exp specific checks
self.specific_constraints[Op.Exp].append(TFLiteSemantic.constraint_input_signed)
+ # Transpose specific checks
+ self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_size)
+ self.specific_constraints[Op.Transpose].append(TFLiteSemantic.constraint_transpose_permutation_values)
+
def is_operator_semantic_valid(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -833,6 +837,24 @@ class TFLiteSemantic:
extra = ", ".join(extra)
return valid, f"Op has non-variable state tensor(s): {extra}"
+ @staticmethod
+ def constraint_transpose_permutation_size(op):
+ "Permutation array must be a 1D tensor with RANK(IFM) elements"
+ dims = len(op.inputs[0].shape)
+ perm = op.inputs[1]
+ valid = len(perm.shape) == 1 and perm.shape[0] == dims
+ return valid, f"Op has ifm_dimension={dims} and permutation shape {perm.shape}"
+
+ @staticmethod
+ def constraint_transpose_permutation_values(op):
+ "Permutation array must have constant values in the range [0, RANK(IFM))"
+ dims = len(op.inputs[0].shape)
+ perm = op.inputs[1]
+ valid = False
+ if perm.values is not None:
+ valid = not any([val < 0 or val >= dims for val in perm.values])
+ return valid, f"Op has ifm_dimension={dims} and permutation values are: {perm.values}"
+
def tflite_semantic_checker(nng):
semantic_checker = TFLiteSemantic()