aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_supported_operators.py')
-rw-r--r--ethosu/vela/test/test_supported_operators.py85
1 files changed, 85 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 973b820d..f1e8f28f 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -506,6 +506,89 @@ def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
+def create_pad_op(
+ in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32
+):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, in_dtype, "in")
+ in0.quantization = qp
+ pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ out = Tensor(out_shape, out_dtype, "out")
+ out.quantization = qp.clone()
+ op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+
+ conv_out_tens = Tensor(in_shape, in_dtype, "output")
+ conv_out_tens.quantization = qp.clone()
+ weight_tens = Tensor(in_shape, in_dtype, "weights")
+ weight_tens.values = np.zeros(weight_tens.shape)
+ weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
+ weight_tens.quantization = qp.clone()
+ bias_tens = Tensor([in_shape[-1]], pad_dtype, "biases")
+ attrs = {"padding": Padding.VALID, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
+ attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+ conv2d_op = testutil.create_op(Op.Conv2D, [out, weight_tens, bias_tens], conv_out_tens, attrs)
+ conv2d_op.add_input_tensor(out)
+ conv2d_op.set_ifm_ofm_shapes()
+ return op
+
+
+def test_constraint_pad_input_count():
+ # Incorrect number of input tensors (2)
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
+ assert support.is_operator_supported(op)
+ op.add_input_tensor(op.inputs[0].clone())
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_padded_dimensions():
+ # Incorrect padding dimensions, can only pad width and height
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_shape():
+ # PAD operator must be of shape (4,2)
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_none():
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[],)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_dtype():
+ # PAD operator dtype should be int32 or int64
+ op = create_pad_op(
+ in_shape=[1, 1, 1, 1],
+ out_shape=[1, 3, 3, 1],
+ padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],
+ pad_dtype=DataType.int16,
+ )
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_consumer():
+ # PAD operator must be followed by a valid consumer with Padding.VALID attribute
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],)
+ conv_op = op.ofm.consumers()[0]
+ conv_op.attrs["Padding"] = Padding.SAME
+ assert not support.is_operator_supported(op)
+ op_consumer = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
+ op.ofm.consumer_list = [op_consumer]
+ assert not support.is_operator_supported(op)
+ op_consumer = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op_consumer.attrs = {
+ "stride_w": 2,
+ "stride_h": 2,
+ "filter_width": 2,
+ "filter_height": 2,
+ "padding": Padding.VALID,
+ }
+ op.ofm.consumer_list = [op_consumer]
+ assert not support.is_operator_supported(op)
+
+
def create_strided_slice():
# Creates a valid strided slice operator with some valid inputs/outputs
op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
@@ -646,6 +729,8 @@ def test_constraint_matching_quantization_parameters():
# valid - all matching
op.ofm.quantization = qp
assert support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], None, [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
def test_constraint_elemwise_batch_size():