diff options
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/test/test_tflite_supported_operators.py | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py index 4aca00da..cbad1713 100644 --- a/ethosu/vela/test/test_tflite_supported_operators.py +++ b/ethosu/vela/test/test_tflite_supported_operators.py @@ -106,23 +106,24 @@ def test_constraint_conv_pass(): @pytest.mark.parametrize( - "stride_w, stride_h, supported", + "ifm_shape, stride_w, stride_h, supported", [ - [0, 20, False], - [20, 0, False], - [4, 3, True], - [4, 5, False], - [4, 9, False], - [3, 3, True], - [1, 1, True], - [20, 2, True], - [6, 3, True], - [8, 1, True], + [[1, 8, 8, 8], 0, 20, False], + [[1, 8, 8, 8], 20, 0, False], + [[1, 8, 8, 8], 4, 3, True], + [[1, 8, 8, 8], 4, 5, False], + [[1, 8, 8, 8], 4, 9, False], + [[1, 8, 8, 8], 3, 3, True], + [[1, 8, 8, 8], 1, 1, True], + [[1, 8, 8, 8], 20, 2, False], + [[1, 8, 40, 8], 20, 2, True], + [[1, 8, 40, 8], 6, 3, True], + [[1, 8, 40, 8], 8, 1, True], ], ) -def test_constraint_stride_range(stride_w: int, stride_h: int, supported: bool): +def test_constraint_stride_range(ifm_shape: list[int], stride_w: int, stride_h: int, supported: bool): # Stride width and height must lie within a certain range - op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8], [1, 1, 1, 1]) + op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, ifm_shape, [1, 8, 8, 8], [1, 1, 1, 1]) op.attrs = {"stride_w": stride_w, "stride_h": stride_h} assert support.is_operator_supported(op) == supported |