author     Erik Andersson <erik.andersson@arm.com>  2020-12-10 14:58:23 +0100
committer  patrik.gustavsson <patrik.gustavsson@arm.com>  2020-12-22 15:02:51 +0000
commit     f27a8b65f8cff8fc52db8e39a6eb8f78b6616c6b (patch)
tree       c153579837c3fc1d1803057998a42533e5320492
parent     ae2d553c4f3dd71a1df6c0e8c9cb920ae584b59e (diff)
MLBEDSW-3711: Added operator checks for PAD.
Constraints and unit tests were added to check the new pad operator.

Change-Id: Id6d4cf2c4da486928c8f46ba1fa124eec66895a6
Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com>
 ethosu/vela/supported_operators.py           | 69 ++++++++-
 ethosu/vela/test/test_supported_operators.py | 85 +++++++++
 2 files changed, 152 insertions(+), 2 deletions(-)
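For context, a minimal sketch of how these new checks surface to a caller. This assumes the create_pad_op helper added in the test file below and the public is_operator_supported entry point; names are taken from this patch, not from any other API:

    from ethosu.vela.supported_operators import SupportedOperators

    support = SupportedOperators()

    # A PAD that only pads height and width, and whose output feeds a
    # convolution with padding=VALID, satisfies every constraint added here.
    op = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[0, 0], [1, 1], [1, 1], [0, 0]],  # shape [4, 2]; N and C rows zero
    )
    assert support.is_operator_supported(op)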
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index f2c2eb9..2e35d77 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -82,6 +82,7 @@ class SupportedOperators:
binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
+ pad_ops = set((Op.Pad,))
supported_int32_tensor_ops = (
set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
@@ -101,10 +102,11 @@ class SupportedOperators:
shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
per_axis_quant_ops = convolution_like_ops # per-axis/channel quantization only currently supported for conv ops
supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
- supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
+ supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops
# Supported data types
supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
supported_bias_dtypes = set((DataType.int32, DataType.int64))
+ supported_pad_dtypes = set((DataType.int32, DataType.int64))
# Defined ranges for allowed values:
tens_dim_range = (1, 65535)
stride_range = (1, 3)
@@ -115,6 +117,8 @@ class SupportedOperators:
filter_range = (1, 8)
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
+ # Supported consumers
+ supported_pad_consumers = convolution_ops | depthwise_convolution_ops
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -251,6 +255,16 @@ class SupportedOperators:
# FullyConnected specific checks:
self.specific_constraints[Op.FullyConnected].append(SupportedOperators.constraint_fc_output_2d)
+ # Pad specific checks:
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_matching_in_out_types)
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_matching_quantization_parameters)
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_input_count)
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_shape)
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_padding_dimensions)
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_type)
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
+ self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in SupportedOperators.supported_operators:
@@ -769,6 +783,57 @@ class SupportedOperators:
return valid, f"Op has {inputs} inputs"
@staticmethod
+ def constraint_pad_input_count(op):
+ "Number of input tensors must be exactly 2"
+ inputs = len(op.inputs)
+ valid = inputs == 2
+ return valid, f"Op has {inputs} inputs"
+
+ @staticmethod
+ def constraint_pad_shape(op):
+ "The padding tensor must have the shape [4,2]"
+ valid = op.inputs[1].shape == [4, 2]
+ return valid, f"The pad tensor has the shape: {op.inputs[1].shape}"
+
+ @classmethod
+ @docstring_format_args([_list_formatter(supported_pad_dtypes)])
+ def constraint_pad_type(cls, op):
+ "Pad tensor must be of type: {}"
+ pad_tensor = op.inputs[1]
+ valid = pad_tensor.dtype in cls.supported_pad_dtypes
+ return valid, f"Tensor '{pad_tensor.name}' has data type: {pad_tensor.dtype}"
+
+ @staticmethod
+ def constraint_padding_dimensions(op):
+ "The pad tensor can only pad width and height"
+ pad_tensor = op.inputs[1].values
+ valid = sum(pad_tensor[0, :]) + sum(pad_tensor[-1, :]) == 0
+ return valid, f"First dimension padding: {pad_tensor[0,:]}, last dimension padding: {pad_tensor[-1,:]}"
+
+ @staticmethod
+ def constraint_pad_constant(op):
+ "The padding tensor must be constant"
+ pad_tensor = op.inputs[1].values
+ valid = pad_tensor is not None
+ return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"
+
+ @classmethod
+ @docstring_format_args([_optype_formatter(supported_pad_consumers)])
+ def constraint_pad_ofm(cls, op):
+ "Must be followed by one of the following operator types: {}"
+ consumers = op.ofm.consumers()
+ consumers_to_pad = 0
+ for consumer in consumers:
+ if consumer.type in cls.supported_pad_consumers:
+ if consumer.attrs["padding"] == Padding.VALID:
+ consumers_to_pad += 1
+ valid = len(consumers) > 0 and len(consumers) == consumers_to_pad
+ return (
+ valid,
+ f"Operator is followed by {consumers_to_pad} consumers with "
+ f"padding set to VALID, out of {len(consumers)} consumers",
+ )
+
+ @staticmethod
def constraint_stridedslice_inputs_const(op):
"Begin, End and Stride Input tensors must be constant"
valid = True
@@ -870,7 +935,7 @@ class SupportedOperators:
if not check_quantized_tens_scaling_equal(op.ofm, op.ifm):
valid = False
extra.append(op.ifm.name)
- if not check_quantized_tens_scaling_equal(op.ofm, op.ifm2):
+ if op.ifm2 is not None and not check_quantized_tens_scaling_equal(op.ofm, op.ifm2):
valid = False
extra.append(op.ifm2.name)
extra = ", ".join(extra)
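A side note on constraint_padding_dimensions above: the padding input is read as a [4, 2] array whose first row is the batch dimension and last row the channel dimension, and since TFLite pad values are non-negative, a zero sum means both rows are all zeros. A standalone sketch of the same arithmetic in plain NumPy:

    import numpy as np

    # Height/width padding only: row 0 (batch) and row -1 (channels) are zero.
    pad_values = np.array([[0, 0], [1, 1], [1, 1], [0, 0]])
    assert sum(pad_values[0, :]) + sum(pad_values[-1, :]) == 0  # supported

    # Any batch or channel padding makes the sum non-zero.
    pad_values = np.array([[1, 0], [1, 1], [1, 1], [0, 0]])
    assert sum(pad_values[0, :]) + sum(pad_values[-1, :]) != 0  # rejected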
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 973b820..f1e8f28 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -506,6 +506,89 @@ def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
+def create_pad_op(
+ in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32
+):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, in_dtype, "in")
+ in0.quantization = qp
+ pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ out = Tensor(out_shape, out_dtype, "out")
+ out.quantization = qp.clone()
+ op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+
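+ # Attach a convolution consumer with padding=VALID so the
+ # constraint_pad_ofm check has a supported consumer to inspect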
+ conv_out_tens = Tensor(in_shape, in_dtype, "output")
+ conv_out_tens.quantization = qp.clone()
+ weight_tens = Tensor(in_shape, in_dtype, "weights")
+ weight_tens.values = np.zeros(weight_tens.shape)
+ weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
+ weight_tens.quantization = qp.clone()
+ bias_tens = Tensor([in_shape[-1]], pad_dtype, "biases")
+ attrs = {"padding": Padding.VALID, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
+ attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+ conv2d_op = testutil.create_op(Op.Conv2D, [out, weight_tens, bias_tens], conv_out_tens, attrs)
+ conv2d_op.add_input_tensor(out)
+ conv2d_op.set_ifm_ofm_shapes()
+ return op
+
+
+def test_constraint_pad_input_count():
+ # Valid with exactly 2 input tensors; adding a third makes the op unsupported
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
+ assert support.is_operator_supported(op)
+ op.add_input_tensor(op.inputs[0].clone())
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_padded_dimensions():
+ # Incorrect padding dimensions, can only pad width and height
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_shape():
+ # The padding tensor must have shape [4, 2]; a fifth row makes it invalid
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_none():
+ # An empty padding tensor fails the shape constraint
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[],)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_dtype():
+ # The padding tensor dtype must be int32 or int64
+ op = create_pad_op(
+ in_shape=[1, 1, 1, 1],
+ out_shape=[1, 3, 3, 1],
+ padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+ pad_dtype=DataType.int16,
+ )
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_pad_consumer():
+ # PAD operator must be followed by a valid consumer with Padding.VALID attribute
+ op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
+ conv_op = op.ofm.consumers()[0]
+ conv_op.attrs["padding"] = Padding.SAME
+ assert not support.is_operator_supported(op)
+ op_consumer = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
+ op.ofm.consumer_list = [op_consumer]
+ assert not support.is_operator_supported(op)
+ op_consumer = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op_consumer.attrs = {
+ "stride_w": 2,
+ "stride_h": 2,
+ "filter_width": 2,
+ "filter_height": 2,
+ "padding": Padding.VALID,
+ }
+ op.ofm.consumer_list = [op_consumer]
+ assert not support.is_operator_supported(op)
+
+
def create_strided_slice():
# Creates a valid strided slice operator with some valid inputs/outputs
op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
@@ -646,6 +729,8 @@ def test_constraint_matching_quantization_parameters():
# valid - all matching
op.ofm.quantization = qp
assert support.is_operator_supported(op)
+ # valid - ifm2 is None and the scaling check is skipped for it
+ op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], None, [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
def test_constraint_elemwise_batch_size():