aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_model_semantic.py
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-11-13 10:23:32 +0100
committerJohan Alfven <johan.alfven@arm.com>2023-11-13 15:08:12 +0100
commitf418e832ffd5a10f549aa07a0c9c59406a374ffe (patch)
treeb91d152f48c346bee172a0f554deb0570956c44f /ethosu/vela/test/test_tflite_model_semantic.py
parenta8fda88bced0d11441467b6798885101d41ac8e9 (diff)
downloadethos-u-vela-f418e832ffd5a10f549aa07a0c9c59406a374ffe.tar.gz
MLBEDSW-8317: Add semantic checks for Transpose
- Added semantic checks for Transpose - Added unit tests for semantic checks - Updated SUPPORTED_OPS.md Change-Id: I3fcf13120f4b6811f8de27711996cdb9c19c9847 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_tflite_model_semantic.py')
-rw-r--r--ethosu/vela/test/test_tflite_model_semantic.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 7ca1bbda..bea004ae 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -647,3 +647,29 @@ def test_lstm_semantics():
op.inputs.pop()
# Test restored valid configuration
assert semantic_checker.is_operator_semantic_valid(op)
+
+
+def test_transpose_semantics():
+ # Test valid op
+ ifm = Tensor([2, 4], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [2], DataType.int32, [1, 0])
+ ofm = Tensor([4, 2], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation size
+ perm = create_const_tensor("perm", [3], DataType.int32, [1, 0])
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation values
+ perm = create_const_tensor("perm", [2], DataType.int32, [2, 0])
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation values
+ perm = create_const_tensor("perm", [2], DataType.int32, [0, -1])
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)
+ # Test invalid permutation values
+ perm = create_const_tensor("perm", [2], DataType.int32, [1, 0])
+ perm.values = None
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not semantic_checker.is_operator_semantic_valid(op)