aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_supported_operators.py
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-10-28 16:04:46 +0200
committerJohan Alfven <johan.alfven@arm.com>2023-11-09 11:59:58 +0100
commita8fda88bced0d11441467b6798885101d41ac8e9 (patch)
tree807de7fa4eee48720255fbed4a605218f8612f6a /ethosu/vela/test/test_tflite_supported_operators.py
parent4bf0cdf58416edc030ae7507ace95224785e4aa8 (diff)
downloadethos-u-vela-a8fda88bced0d11441467b6798885101d41ac8e9.tar.gz
MLBEDSW-8290: MLCE: Add TRANSPOSE support3.10.0.rc1
- Added graph optimiser function to convert TRANSPOSE op into an AvgPool op with swapped stride for height and width - Added TRANSPOSE supported op check - Added unit tests for TRANSPOSE supported op check - Updated SUPPORTED_OPS.md - Fixed problem in pass packing when optimizing the pass list. Old problem, but now seen when moving TRANSPOSE from cpu. Change-Id: I0a0ef420b0fb8241090c2e2434622881105cde15 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index a433fb8d..e65717a8 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -759,3 +759,86 @@ def test_constraint_slice_inputs_const():
op.set_input_tensor(begin, 1)
op.set_input_tensor(begin, 2)
assert support.is_operator_supported(op)
+
+
+def test_constraint_transpose():
+ # Test supported op IFM rank 2
+ ifm = Tensor([2, 4], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [2], DataType.int32, [1, 0])
+ ofm = Tensor([4, 2], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert support.is_operator_supported(op)
+ # Test supported op IFM rank 3
+ ifm = Tensor([2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [3], DataType.int32, [1, 0, 2])
+ ofm = Tensor([4, 2, 6], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert support.is_operator_supported(op)
+ ifm = Tensor([1, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [3], DataType.int32, [0, 2, 1])
+ ofm = Tensor([1, 6, 4], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert support.is_operator_supported(op)
+ ifm = Tensor([2, 1, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [3], DataType.int32, [2, 1, 0])
+ ofm = Tensor([6, 1, 2], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert support.is_operator_supported(op)
+ # Test supported op IFM rank 4
+ ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [0, 2, 1, 3])
+ ofm = Tensor([1, 4, 2, 6], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert support.is_operator_supported(op)
+ ifm = Tensor([1, 1, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [0, 1, 3, 2])
+ ofm = Tensor([1, 1, 6, 4], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert support.is_operator_supported(op)
+ ifm = Tensor([1, 2, 1, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [0, 3, 2, 1])
+ ofm = Tensor([1, 6, 1, 2], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert support.is_operator_supported(op)
+ # Test not supported op IFM rank 3
+ ifm = Tensor([2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [3], DataType.int32, [0, 2, 1])
+ ofm = Tensor([2, 6, 4], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)
+ ifm = Tensor([2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [3], DataType.int32, [2, 1, 0])
+ ofm = Tensor([6, 2, 2], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)
+ # Test not supported op IFM rank 4
+ ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [0, 1, 3, 2])
+ ofm = Tensor([1, 2, 6, 4], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)
+ ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [0, 3, 2, 1])
+ ofm = Tensor([1, 6, 4, 2], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)
+ ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [1, 0, 2, 3])
+ ofm = Tensor([2, 1, 4, 6], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)
+ ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [2, 1, 0, 3])
+ ofm = Tensor([4, 2, 1, 6], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)
+ ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [3, 1, 2, 0])
+ ofm = Tensor([6, 2, 4, 1], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)
+ ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm")
+ perm = create_const_tensor("perm", [4], DataType.int32, [3, 2, 1, 0])
+ ofm = Tensor([6, 4, 2, 1], DataType.int8, "ofm")
+ op = testutil.create_op(Op.Transpose, [ifm, perm], ofm)
+ assert not support.is_operator_supported(op)