diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index a433fb8..e65717a 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -759,3 +759,86 @@ def test_constraint_slice_inputs_const(): op.set_input_tensor(begin, 1) op.set_input_tensor(begin, 2) assert support.is_operator_supported(op) + + +def test_constraint_transpose(): + # Test supported op IFM rank 2 + ifm = Tensor([2, 4], DataType.int8, "ifm") + perm = create_const_tensor("perm", [2], DataType.int32, [1, 0]) + ofm = Tensor([4, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert support.is_operator_supported(op) + # Test supported op IFM rank 3 + ifm = Tensor([2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [3], DataType.int32, [1, 0, 2]) + ofm = Tensor([4, 2, 6], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert support.is_operator_supported(op) + ifm = Tensor([1, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [3], DataType.int32, [0, 2, 1]) + ofm = Tensor([1, 6, 4], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert support.is_operator_supported(op) + ifm = Tensor([2, 1, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [3], DataType.int32, [2, 1, 0]) + ofm = Tensor([6, 1, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert support.is_operator_supported(op) + # Test supported op IFM rank 4 + ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [0, 2, 1, 3]) + ofm = Tensor([1, 4, 2, 6], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert support.is_operator_supported(op) + ifm = Tensor([1, 1, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [0, 1, 3, 2]) + ofm = Tensor([1, 1, 6, 4], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert support.is_operator_supported(op) + ifm = Tensor([1, 2, 1, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [0, 3, 2, 1]) + ofm = Tensor([1, 6, 1, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert support.is_operator_supported(op) + # Test not supported op IFM rank 3 + ifm = Tensor([2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [3], DataType.int32, [0, 2, 1]) + ofm = Tensor([2, 6, 4], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) + ifm = Tensor([2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [3], DataType.int32, [2, 1, 0]) + ofm = Tensor([6, 2, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) + # Test not supported op IFM rank 4 + ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [0, 1, 3, 2]) + ofm = Tensor([1, 2, 6, 4], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) + ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [0, 3, 2, 1]) + ofm = Tensor([1, 6, 4, 2], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) + ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [1, 0, 2, 3]) + ofm = Tensor([2, 1, 4, 6], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) + ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [2, 1, 0, 3]) + ofm = Tensor([4, 2, 1, 6], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) + ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [3, 1, 2, 0]) + ofm = Tensor([6, 2, 4, 1], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) + ifm = Tensor([1, 2, 4, 6], DataType.int8, "ifm") + perm = create_const_tensor("perm", [4], DataType.int32, [3, 2, 1, 0]) + ofm = Tensor([6, 4, 2, 1], DataType.int8, "ofm") + op = testutil.create_op(Op.Transpose, [ifm, perm], ofm) + assert not support.is_operator_supported(op) |