diff options
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r-- | ethosu/vela/test/test_supported_operators.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 973b820d..f1e8f28f 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -506,6 +506,89 @@ def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets): return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs) +def create_pad_op( + in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32 +): + qp = testutil.default_quant_params() + in0 = Tensor(in_shape, in_dtype, "in") + in0.quantization = qp + pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype) + out = Tensor(out_shape, out_dtype, "out") + out.quantization = qp.clone() + op = testutil.create_op(Op.Pad, [in0, pad_tensor], out) + + conv_out_tens = Tensor(in_shape, in_dtype, "output") + conv_out_tens.quantization = qp.clone() + weight_tens = Tensor(in_shape, in_dtype, "weights") + weight_tens.values = np.zeros(weight_tens.shape) + weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8) + weight_tens.quantization = qp.clone() + bias_tens = Tensor([in_shape[-1]], pad_dtype, "biases") + attrs = {"padding": Padding.VALID, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1} + attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1) + conv2d_op = testutil.create_op(Op.Conv2D, [out, weight_tens, bias_tens], conv_out_tens, attrs) + conv2d_op.add_input_tensor(out) + conv2d_op.set_ifm_ofm_shapes() + return op + + +def test_constraint_pad_input_count(): + # Incorrect number of input tensors (2) + op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],) + assert support.is_operator_supported(op) + op.add_input_tensor(op.inputs[0].clone()) + assert not support.is_operator_supported(op) + + +def test_constraint_padded_dimensions(): + # Incorrect padding dimensions, can only pad width and height + op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],) + assert not support.is_operator_supported(op) + + +def test_constraint_pad_shape(): + # PAD operator must be of shape (4,2) + op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],) + assert not support.is_operator_supported(op) + + +def test_constraint_pad_none(): + op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[],) + assert not support.is_operator_supported(op) + + +def test_constraint_pad_dtype(): + # PAD operator dtype should be int32 or int64 + op = create_pad_op( + in_shape=[1, 1, 1, 1], + out_shape=[1, 3, 3, 1], + padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]], + pad_dtype=DataType.int16, + ) + assert not support.is_operator_supported(op) + + +def test_constraint_pad_consumer(): + # PAD operator must be followed by a valid consumer with Padding.VALID attribute + op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],) + conv_op = op.ofm.consumers()[0] + conv_op.attrs["Padding"] = Padding.SAME + assert not support.is_operator_supported(op) + op_consumer = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8]) + op.ofm.consumer_list = [op_consumer] + assert not support.is_operator_supported(op) + op_consumer = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8]) + op_consumer.attrs = { + "stride_w": 2, + "stride_h": 2, + "filter_width": 2, + "filter_height": 2, + "padding": Padding.VALID, + } + op.ofm.consumer_list = [op_consumer] + assert not support.is_operator_supported(op) + + def create_strided_slice(): # Creates a valid strided slice operator with some valid inputs/outputs op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0]) @@ -646,6 +729,8 @@ def test_constraint_matching_quantization_parameters(): # valid - all matching op.ofm.quantization = qp assert support.is_operator_supported(op) + op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], None, [1, 8, 8, 8]) + assert support.is_operator_supported(op) def test_constraint_elemwise_batch_size(): |