aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_tflite_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_tflite_supported_operators.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index f54211f..e5cc280 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -218,12 +218,47 @@ def test_constraint_depth_multiplier():
def test_constraint_tconv_stride():
- # Strides must be 2
+ # Valid 2x2
op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert support.is_operator_supported(op)
+ # Valid 1x1
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1, "padding": Padding.SAME}
ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
ifm.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm)
+ assert support.is_operator_supported(op)
+ # Valid 2x1 (WxH) ifm h and kernel h = 1
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 1, 2, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 1, "padding": Padding.SAME}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert support.is_operator_supported(op)
+ # Invalid 2x1 (WxH) ifm h = 2 and kernel h = 1
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 1, 2, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 1, "padding": Padding.SAME}
+ ifm = Tensor([1, 2, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert not support.is_operator_supported(op)
+ # Invalid 2x1 (WxH) ifm h = 1 and kernel h = 2
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 1, 1, 1], weights_shape=[1, 2, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 1, "padding": Padding.SAME}
+ ifm = Tensor([1, 2, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert not support.is_operator_supported(op)
+ # Invalid 1x2 (WxH)
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 1, "stride_h": 2, "padding": Padding.SAME}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
assert not support.is_operator_supported(op)