diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 3872bdc8..35fc1a6f 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -550,25 +550,25 @@ def test_constraint_matching_quantization_parameters(): def test_constraint_elemwise_batch_size(): # BINARY CASE - # Batch can be >1 if dims is <=2D - op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [2, 2], [2, 2]) + # Batch can be >1 if dims is <=3D + op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2]) assert support.is_operator_supported(op) - # For dims >2D, batch must be 1 - op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2], [1, 2, 2], [1, 2, 2]) + # For dims >3D, batch must be 1 + op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2]) assert support.is_operator_supported(op) # invalid case - op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2]) + op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]) assert not support.is_operator_supported(op) # UNARY CASE - # Batch can be >1 if dims is <=2D - op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32) + # Batch can be >1 if dims is <=3D + op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32) assert support.is_operator_supported(op) - # For dims >2D, batch must be 1 - op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2], None, [1, 2, 2], datatype=DataType.int32) + # For dims >3D, batch must be 1 + op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2, 2], None, [1, 2, 2, 2], datatype=DataType.int32) assert support.is_operator_supported(op) # invalid case - op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32) + op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2, 2], None, [2, 2, 2, 2], datatype=DataType.int32) assert not support.is_operator_supported(op) |