aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 3872bdc8..35fc1a6f 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -550,25 +550,25 @@ def test_constraint_matching_quantization_parameters():
def test_constraint_elemwise_batch_size():
# BINARY CASE
- # Batch can be >1 if dims is <=2D
- op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [2, 2], [2, 2])
+ # Batch can be >1 if dims is <=3D
+ op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
assert support.is_operator_supported(op)
- # For dims >2D, batch must be 1
- op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2], [1, 2, 2], [1, 2, 2])
+ # For dims >3D, batch must be 1
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2])
assert support.is_operator_supported(op)
# invalid case
- op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
+ op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2])
assert not support.is_operator_supported(op)
# UNARY CASE
- # Batch can be >1 if dims is <=2D
- op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
+ # Batch can be >1 if dims is <=3D
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
assert support.is_operator_supported(op)
- # For dims >2D, batch must be 1
- op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2], None, [1, 2, 2], datatype=DataType.int32)
+ # For dims >3D, batch must be 1
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2, 2], None, [1, 2, 2, 2], datatype=DataType.int32)
assert support.is_operator_supported(op)
# invalid case
- op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2, 2], None, [2, 2, 2, 2], datatype=DataType.int32)
assert not support.is_operator_supported(op)