author    Michael McGeagh <michael.mcgeagh@arm.com>  2020-10-20 11:49:28 +0100
committer Michael McGeagh <michael.mcgeagh@arm.com>  2020-11-04 14:11:24 +0000
commit    65fd99830a762b2c59aaa446b55cbfa43a92f8ba (patch)
tree      2320b8b0573234c7976d8228679f3b8f577b4590
parent    37ce38c208601c6a7901d2dc266ed7db6842405b (diff)
download  ethos-u-vela-65fd99830a762b2c59aaa446b55cbfa43a92f8ba.tar.gz
MLBEDSW-2412 All constraints have been refactored
All existing constraints have now been refactored using the new framework.

Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: Ic9ba0d7040cb9f114b959a949bfdf777f86752c7
-rw-r--r--  ethosu/vela/operation.py                        42
-rw-r--r--  ethosu/vela/supported_operators.py            1003
-rw-r--r--  ethosu/vela/test/test_supported_operators.py   528
-rw-r--r--  ethosu/vela/test/testutil.py                    32
4 files changed, 1047 insertions, 558 deletions
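For reference, every refactored check in this commit follows the same pattern: a constraint is a plain function whose docstring is the user-facing reason and whose return value is a (valid, extra) pair; generic constraints run for every op, while specific constraints are registered per op type and looked up in a defaultdict. A minimal, self-contained sketch of that pattern (MiniTensor, MiniOp and MiniSupportedOperators are simplified stand-ins, not the real Vela classes):

    from collections import defaultdict

    class MiniTensor:
        def __init__(self, name, shape):
            self.name, self.shape = name, shape

    class MiniOp:
        def __init__(self, op_type, name, ifm, ofm):
            self.type, self.name, self.ifm, self.ofm = op_type, name, ifm, ofm

    class MiniSupportedOperators:
        def __init__(self):
            # Generic constraints run for every op; the order they run in matters
            self.generic_constraints = [MiniSupportedOperators.constraint_batch_size]
            # Specific constraints are keyed by a single op type
            self.specific_constraints = defaultdict(list)
            self.specific_constraints["Softmax"].append(MiniSupportedOperators.constraint_matching_shapes)

        def is_operator_supported(self, op):
            for constraint in self.generic_constraints + self.specific_constraints[op.type]:
                valid, extra = constraint(op)
                if not valid:
                    print(f"Warning: {op.type} '{op.name}' is not supported on the NPU. Placing on CPU instead")
                    print(f" - {constraint.__doc__}")  # the docstring doubles as the reason
                    if extra:
                        print(f"   {extra}")
                    return False
            return True

        @staticmethod
        def constraint_batch_size(op):
            "IFM Tensor batch size must be 1"
            valid = op.ifm.shape[0] == 1
            return valid, f"Tensor '{op.ifm.name}' has batch size: {op.ifm.shape[0]}"

        @staticmethod
        def constraint_matching_shapes(op):
            "IFM and OFM shapes must match"
            valid = op.ifm.shape == op.ofm.shape
            return valid, f"Op has ifm_shape={op.ifm.shape} and ofm_shape={op.ofm.shape}"

    support = MiniSupportedOperators()
    op = MiniOp("Softmax", "sm1", MiniTensor("ifm", [1, 8, 8, 16]), MiniTensor("ofm", [1, 8, 8, 16]))
    assert support.is_operator_supported(op)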
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index cc52ff4..1ba2a38 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -220,13 +220,13 @@ class Op(Enum):
Sin = OperatorInfo()
SkipGram = OperatorInfo()
Slice = OperatorInfo(indices=IFM_INDICES)
- Softmax = OperatorInfo()
+ Softmax = OperatorInfo(indices=IFM_INDICES)
SpaceToBatchND = OperatorInfo()
SpaceToDepth = OperatorInfo()
SparseToDense = OperatorInfo()
Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
- SplitV = OperatorInfo(indices=IFM_INDICES)
+ SplitV = OperatorInfo(indices=IFM_IFM2_INDICES)
Sqrt = OperatorInfo()
Square = OperatorInfo()
SquaredDifference = OperatorInfo()
@@ -399,19 +399,39 @@ class Operation:
__repr__ = __str__
- @property
- def kernel(self):
- strides = self.attrs.get("strides", (1, 1, 1, 1))
- dilation = self.attrs.get("dilation", (1, 1, 1, 1))
+ def get_kernel_size(self):
weights = self.weights
if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
weight_shape = full_shape(4, weights.shape, 1)
- k_h = weight_shape[-4]
- k_w = weight_shape[-3]
+ h = weight_shape[-4]
+ w = weight_shape[-3]
else:
- k_h = self.attrs.get("filter_height", 1)
- k_w = self.attrs.get("filter_width", 1)
- self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
+ h = self.attrs.get("filter_height", 1)
+ w = self.attrs.get("filter_width", 1)
+ return w, h
+
+ def get_kernel_stride(self):
+ if "strides" in self.attrs:
+ _, h, w, _ = self.attrs["strides"]
+ else:
+ h = self.attrs.get("stride_h", 1)
+ w = self.attrs.get("stride_w", 1)
+ return w, h
+
+ def get_kernel_dilation(self):
+ if "dilation" in self.attrs:
+ _, h, w, _ = self.attrs["dilation"]
+ else:
+ h = self.attrs.get("dilation_h_factor", 1)
+ w = self.attrs.get("dilation_w_factor", 1)
+ return w, h
+
+ @property
+ def kernel(self):
+ k_w, k_h = self.get_kernel_size()
+ s_w, s_h = self.get_kernel_stride()
+ d_w, d_h = self.get_kernel_dilation()
+ self._kernel = Kernel(k_w, k_h, s_w, s_h, d_w, d_h)
return self._kernel
def get_ifm_ifm2_weights_ofm(self):
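The operation.py change above splits the old monolithic kernel property into get_kernel_size(), get_kernel_stride() and get_kernel_dilation(), so constraint checks can query each part independently. The getters also normalize the two attribute conventions an op may carry: a 4-element NHWC-ordered tuple, or separate scalar attributes. A small sketch of that normalization (MiniOp is an illustrative stand-in and the attrs values are made up):

    class MiniOp:
        def __init__(self, attrs):
            self.attrs = attrs

        def get_kernel_stride(self):
            if "strides" in self.attrs:
                _, h, w, _ = self.attrs["strides"]
            else:
                h = self.attrs.get("stride_h", 1)
                w = self.attrs.get("stride_w", 1)
            return w, h

    op_a = MiniOp({"strides": (1, 2, 3, 1)})       # NHWC-ordered tuple: stride_h=2, stride_w=3
    op_b = MiniOp({"stride_h": 2, "stride_w": 3})  # separate scalar attributes
    assert op_a.get_kernel_stride() == op_b.get_kernel_stride() == (3, 2)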
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 24c7291..ddfb8ed 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -25,7 +25,6 @@ from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
from .tensor import check_quantized_tens_scaling_equal
-from .tensor import check_tens_quantized
# Custom decorator function to allow formatting docstrings containing "{}"
@@ -74,7 +73,8 @@ class SupportedOperators:
supported_int32_tensor_ops = (
set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
- activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
+ relu_ops = Op.op_set(Op.is_relu_op)
+ activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,))
npu_post_ops = (
# activation functions
activation_ops
@@ -87,7 +87,7 @@ class SupportedOperators:
concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
- supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
+ supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
# Supported data types
supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
@@ -99,39 +99,17 @@ class SupportedOperators:
dilated_height_range = (1, 64)
dilated_product_range = (1, 64 * 64)
weights_limit = 127 * 65536
+ filter_range = (1, 8)
+ filter_height_range = (1, 256)
+ filter_product_range = (1, 256 * 256)
def __init__(self):
- # Setup supported operator restriction checkers
- self.supported_operator_restrictions = {}
- self.supported_operator_restrictions.update(
- {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
- )
- self.supported_operator_restrictions.update(
- {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
- )
- self.supported_operator_restrictions.update(
- {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
- )
- self.supported_operator_restrictions.update(
- {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
- )
- self.supported_operator_restrictions.update(
- {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
- )
- self.supported_operator_restrictions.update(
- {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
- )
- self.supported_operator_restrictions.update(
- {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
- )
- self.supported_operator_restrictions.update(
- {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
- )
# Setup the generic constraints. Note: the order matters
self.generic_constraints = []
+ self.generic_constraints.append(SupportedOperators.constraint_tens_no_dynamic)
self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
- self.generic_constraints.append(SupportedOperators.constraint_tens_output_shapeless)
- self.generic_constraints.append(SupportedOperators.constraint_tens_input_shapeless)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_output_scalar)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_input_scalar)
self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
@@ -139,76 +117,173 @@ class SupportedOperators:
self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
self.generic_constraints.append(SupportedOperators.constraint_faf)
- # Setup specific constraints. The key in the dictionary must be a tuple of op types the constraints apply to
+
+ # Setup specific constraints. Note: the order matters
self.specific_constraints = defaultdict(list)
- # Conv-like ops have the same checks applied to them:
- conv_like_ops = tuple(SupportedOperators.convolution_like_ops)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_type)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_range)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_type)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_range)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_height_range)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_product_range)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_type)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_nonconst)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_limit)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_type)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_40bit)
- self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_batch_size)
-
- def get_constraints_list(self, op_type):
- constraint_list = list(self.generic_constraints)
- for ops in self.specific_constraints:
- if op_type in ops:
- constraint_list.extend(self.specific_constraints[ops])
- return constraint_list
+
+ # Conv-like checks:
+ for op_type in SupportedOperators.convolution_like_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_range)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_dilation_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_dilation_range)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_dilated_height_range)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_dilated_product_range)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_const)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_limit)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_40bit)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_batch_size)
+ # Depthwise Conv specific checks:
+ for op_type in SupportedOperators.depthwise_convolution_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_depth_multiplier)
+ # Transpose Conv specific checks:
+ for op_type in SupportedOperators.transpose_convolution_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_stride)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_same)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_tconv_valid)
+
+ # Pooling checks:
+ for op_type in SupportedOperators.pooling_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_batch_size)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_stride_range)
+ # AVG pooling specific checks:
+ for op_type in SupportedOperators.avg_pooling_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_range)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_height_range_valid_pad)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_product_range_valid_pad)
+ # MAX pooling specific checks:
+ for op_type in SupportedOperators.max_pooling_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_height_range)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_filter_product_range)
+ # TODO: Check ReduceSum restrictions
+
+ # Relu specific checks:
+ for op_type in SupportedOperators.relu_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_quant_scale_inf)
+
+ # Resizing specific checks:
+ for op_type in SupportedOperators.resizing_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_resize)
+
+ # Vector Product specific checks:
+ for op_type in SupportedOperators.fc_vector_products:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_weights_const)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_type)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_bias_40bit)
+
+ # Concat specific checks:
+ for op_type in (Op.Concat, Op.ConcatTFLite):
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_axis_exists)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_axis_valid)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_dimensionality)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_valid_dimensions)
+
+ # Element-wise checks:
+ for op_type in SupportedOperators.elem_wise_main_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_elemwise_batch_size)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_either_shapes)
+ # Unary specific checks:
+ for op_type in SupportedOperators.unary_elem_wise_main_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+ # Binary Min/Max specific checks:
+ for op_type in SupportedOperators.binary_elem_wise_min_max_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_in_out_types)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_quantization_parameters)
+ # Binary Add/Mul/Sub specific checks:
+ for op_type in SupportedOperators.binary_elem_wise_add_mul_sub:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_inputs_types)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_matching_signed)
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_unsigned_valid)
+ # Binary Shift specific checks:
+ for op_type in SupportedOperators.binary_elem_wise_shift_ops:
+ self.specific_constraints[op_type].append(SupportedOperators.constraint_inputs_int32)
+
+ # SHL specific checks:
+ self.specific_constraints[Op.SHL].append(SupportedOperators.constraint_output_int32)
+
+ # CLZ specific checks:
+ self.specific_constraints[Op.CLZ].append(SupportedOperators.constraint_output_int32)
+
+ # Softmax specific checks:
+ self.specific_constraints[Op.Softmax].append(SupportedOperators.constraint_matching_shapes)
+ self.specific_constraints[Op.Softmax].append(SupportedOperators.constraint_matching_in_out_types)
+
+ # SplitV specific checks:
+ self.specific_constraints[Op.SplitV].append(SupportedOperators.constraint_splitv_inferred)
+
+ # StridedSlice specific checks:
+ self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_input_count)
+ self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_inputs_const)
+ self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_tens_size_matches)
+ self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_stride_values)
+ self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_ellipsis_mask)
+ self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_axis_masks)
+ self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_slice_ranges)
+
+ # LeakyRelu specific checks:
+ self.specific_constraints[Op.LeakyRelu].append(SupportedOperators.constraint_alpha_valid)
def is_operator_supported(self, op):
if op.type not in SupportedOperators.supported_operators:
if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
- print("Info: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
+ print(f"Info: {op.type} '{op.name}' is not supported on the NPU. Placing on CPU instead")
return False
- for constraint in self.get_constraints_list(op.type):
+ for constraint in self.generic_constraints + self.specific_constraints[op.type]:
valid, extra = constraint(op)
if not valid:
- print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
- print(" - {}".format(constraint.__doc__))
+ print(f"Warning: {op.type} '{op.name}' is not supported on the NPU. Placing on CPU instead")
+ print(f" - {constraint.__doc__}")
if extra:
- print(" {}".format(extra))
+ print(f" {extra}")
return False
- if op.type in self.supported_operator_restrictions:
- return self.supported_operator_restrictions[op.type](op)
return True
@staticmethod
- def constraint_tens_defined_shape(op):
- "Input(s) and Output Tensors must have a defined shape"
+ def constraint_tens_no_dynamic(op):
+ "Input(s) and Output tensors must not be dynamic"
valid = True
extra = []
tensors = [tens for tens in op.inputs + op.outputs if tens]
for tens in tensors:
- if not tens.has_fully_defined_shape():
+ if (tens.shape == []) and (tens.values is None):
valid = False
- extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
- return valid, ", ".join(extra)
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has dynamic tensor(s): {extra}"
@staticmethod
- def constraint_tens_output_shapeless(op):
- "Scalar or Broadcasting Tensors are only valid for Input Tensors"
+ def constraint_tens_defined_shape(op):
+ "Input(s) and Output tensors must have a defined shape"
valid = True
extra = []
- for tens in op.outputs:
- if tens.shape == []:
+ tensors = [tens for tens in op.inputs + op.outputs if tens]
+ for tens in tensors:
+ if not tens.has_fully_defined_shape():
valid = False
- extra.append("Output Tensor '{}' is shapeless".format(tens.name))
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
return valid, ", ".join(extra)
+ @staticmethod
+ def constraint_tens_output_scalar(op):
+ "Output tensors cannot be scalar"
+ ofm = op.ofm
+ valid = ofm.shape != []
+ return valid, f"Output Tensor '{ofm.name}' is scalar"
+
@classmethod
@docstring_format_args([shapeless_input_ops])
- def constraint_tens_input_shapeless(cls, op):
- "Scalar or Broadcasting Input Tensors are only valid for op type: {}"
+ def constraint_tens_input_scalar(cls, op):
+ "Scalar Input tensors are only valid for op type: {}"
valid = True
extra = []
tensors = [tens for tens in op.inputs if tens]
@@ -216,33 +291,34 @@ class SupportedOperators:
if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
valid = False
extra.append(tens.name)
- extra = "Op has shapeless input tensor(s): {}".format(", ".join(extra))
- return valid, extra
+ extra = ", ".join(extra)
+ return valid, f"Op has scalar input tensor(s): {extra}"
@staticmethod
def constraint_tens_shape_size(op):
- "Input(s) and Output Tensors must not be greater than 4D"
+ "Input(s) and Output tensors must not be greater than 4D"
valid = True
extra = []
tensors = [tens for tens in op.inputs + op.outputs if tens]
for tens in tensors:
if len(tens.shape) > 4:
valid = False
- extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
return valid, ", ".join(extra)
@classmethod
@docstring_format_args([supported_op_dtypes])
def constraint_tens_dtype(cls, op):
- "Input(s), Output and Weight Tensors must be of type: {}"
+ "Tensors must be of type: {}"
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- tensors = tensors if tensors else op.inputs
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
for tens in tensors:
if tens.dtype not in cls.supported_op_dtypes:
valid = False
- extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
+ extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}")
return valid, ", ".join(extra)
@classmethod
@@ -252,13 +328,14 @@ class SupportedOperators:
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- tensors = tensors if tensors else op.inputs
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
for tens in tensors:
if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
valid = False
extra.append(tens.name)
- extra = "Op has int32 tensor(s): {}".format(", ".join(extra))
- return valid, extra
+ extra = ", ".join(extra)
+ return valid, f"Op has int32 tensor(s): {extra}"
@classmethod
@docstring_format_args(tens_dim_range)
@@ -268,35 +345,37 @@ class SupportedOperators:
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
- tensors = tensors if tensors else op.inputs
+ if not tensors:
+ tensors = [tens for tens in op.inputs if tens]
for tens in tensors:
if not all(tens_min <= dim <= tens_max for dim in tens.shape):
valid = False
- extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
return valid, ", ".join(extra)
@staticmethod
def constraint_tens_quant_none_check(op):
- "Tensors must have quantization parameters"
+ "Input(s), Output and Weight tensors must have quantization parameters"
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
for tens in tensors:
if tens.quantization is None:
valid = False
- extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
- return valid, ", ".join(extra)
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has tensors with missing quantization parameters: {extra}"
@staticmethod
def constraint_tens_quant_scale(op):
- "Tensors with quantization scales must be finite"
+ "Input(s), Output and Weight tensors with quantization scales must be finite"
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
for tens in tensors:
if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
valid = False
- extra.append("Tensor '{}' has quantization scale: {}".format(tens.name, tens.quantization.scale_f32))
+ extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
return valid, ", ".join(extra)
@classmethod
@@ -305,87 +384,71 @@ class SupportedOperators:
"The fused activation function (if present) must be one of type: {}"
faf = op.activation
valid = (faf is None) or (faf in cls.supported_fused_activations)
- extra = "Op has its fused activation function as: {}".format(faf)
- return valid, extra
+ return valid, f"Op has its fused activation function as: {faf}"
@staticmethod
def constraint_stride_type(op):
"Stride values for both width and height must be integer types"
- w = op.attrs["stride_w"]
- h = op.attrs["stride_h"]
+ w, h = op.get_kernel_stride()
valid = is_integer(w) and is_integer(h)
- extra = "Op has stride WxH as: {}x{}".format(repr(w), repr(h))
- return valid, extra
+ return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"
@classmethod
@docstring_format_args(stride_range)
def constraint_stride_range(cls, op):
"Stride values for both width and height must be in the range [{}, {}]"
- w = op.attrs["stride_w"]
- h = op.attrs["stride_h"]
+ w, h = op.get_kernel_stride()
stride_min, stride_max = cls.stride_range
valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
- extra = "Op has stride WxH as: {}x{}".format(w, h)
- return valid, extra
+ return valid, f"Op has stride WxH as: {w}x{h}"
@staticmethod
def constraint_dilation_type(op):
"Dilation factor values for both width and height must be integer types"
- w = op.attrs.get("dilation_w_factor", 1)
- h = op.attrs.get("dilation_h_factor", 1)
+ w, h = op.get_kernel_dilation()
valid = is_integer(w) and is_integer(h)
- extra = "Op has dilation factor WxH as: {}x{}".format(repr(w), repr(h))
- return valid, extra
+ return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"
@classmethod
@docstring_format_args(dilation_range)
def constraint_dilation_range(cls, op):
"Dilation factor values for both width and height must be in the range [{}, {}]"
- w = op.attrs.get("dilation_w_factor", 1)
- h = op.attrs.get("dilation_h_factor", 1)
+ w, h = op.get_kernel_dilation()
dilation_min, dilation_max = cls.dilation_range
valid = (dilation_min <= w <= dilation_max) and (dilation_min <= h <= dilation_max)
- extra = "Op has dilation factor WxH as: {}x{}".format(w, h)
- return valid, extra
+ return valid, f"Op has dilation factor WxH as: {w}x{h}"
@classmethod
@docstring_format_args(dilated_height_range)
def constraint_dilated_height_range(cls, op):
"Dilated kernel height must be in the range [{}, {}]"
- h = (op.weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
+ h = op.kernel.area_height()
dilated_height_min, dilated_height_max = cls.dilated_height_range
valid = dilated_height_min <= h <= dilated_height_max
- extra = "Op has dilated kernel height as: {}".format(h)
- return valid, extra
+ return valid, f"Op has dilated kernel height as: {h}"
@classmethod
@docstring_format_args(dilated_product_range)
def constraint_dilated_product_range(cls, op):
"Product of dilated kernel width and height must be in the range [{}, {}]"
- weights = op.weights
- w = (weights.shape[1] - 1) * op.attrs.get("dilation_w_factor", 1) + 1
- h = (weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
- product = w * h
+ product = op.kernel.area_width() * op.kernel.area_height()
dilated_product_min, dilated_product_max = cls.dilated_product_range
valid = dilated_product_min <= product <= dilated_product_max
- extra = "Op has product of dilated kernel width and height as: {}".format(product)
- return valid, extra
+ return valid, f"Op has product of dilated kernel width and height as: {product}"
@staticmethod
def constraint_weights_type(op):
- "Weight Tensor must be 8-bit"
+ "Weight tensor must be 8-bit"
weights = op.weights
valid = weights.element_size() == 1
- extra = "Tensor '{}' is {}-bit".format(weights.name, int(weights.element_size() * 8))
- return valid, extra
+ return valid, f"Tensor '{weights.name}' is {int(weights.element_size() * 8)}-bit"
@staticmethod
- def constraint_weights_nonconst(op):
- "Weight tensor cannot be non-constant"
+ def constraint_weights_const(op):
+ "Weight tensor must be constant"
weights = op.weights
valid = weights.values is not None
- extra = "Tensor '{}' has non-constant values".format(weights.name)
- return valid, extra
+ return valid, f"Tensor '{weights.name}' has non-constant values"
@classmethod
@docstring_format_args([weights_limit])
@@ -395,405 +458,409 @@ class SupportedOperators:
values = weights.quant_values.astype(np.int64) - weights.quantization.zero_point
limit = np.amax(np.sum(np.absolute(values), axis=(0, 1, 2)))
valid = limit <= cls.weights_limit
- extra = "Tensor '{}' has the sum of weights: {}".format(weights.name, limit)
- return valid, extra
+ return valid, f"Tensor '{weights.name}' has the sum of weights: {limit}"
@classmethod
@docstring_format_args([supported_bias_dtypes])
def constraint_bias_type(cls, op):
- "Optional Bias Tensor must be of type: {}"
- valid = True
- extra = ""
+ "Optional Bias tensor must be of type: {}"
bias = op.bias
if bias:
valid = bias.dtype in cls.supported_bias_dtypes
- extra = "Tensor '{}' has data type: {}".format(bias.name, bias.dtype)
- return valid, extra
+ return valid, f"Tensor '{bias.name}' has data type: {bias.dtype}"
+ return True, "Op has no bias tensor"
@staticmethod
def constraint_bias_40bit(op):
- "Optional Bias Tensor values must fit within 40-bits"
- valid = True
- extra = ""
+ "Optional Bias tensor values must fit within 40-bits"
bias = op.bias
if bias and bias.dtype == DataType.int64:
valid = all(len(bin(quant_value)[2:]) <= 40 for quant_value in bias.quant_values)
- extra = "Tensor '{}' has values larger than 40-bits".format(bias.name)
- return valid, extra
+ return valid, f"Tensor '{bias.name}' has values larger than 40-bits"
+ return True, "Op has no bias tensor, or it fits in 40-bit"
@staticmethod
def constraint_batch_size(op):
"IFM Tensor batch size must be 1"
ifm = op.ifm
valid = ifm.shape[0] == 1
- extra = "Tensor '{}' has batch size: {}".format(ifm.name, ifm.shape[0])
- return valid, extra
+ return valid, f"Tensor '{ifm.name}' has batch size: {ifm.shape[0]}"
- @classmethod
- def check_depthwise_convolution_restrictions(cls, op):
- # check depth
- ifm_tensor, ofm_tensor = op.get_ifm_ofm()
- if op.attrs["depth_multiplier"] > 1 and not (
- (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
- ):
- print(
- "Warning: for depth multipliers > 1,",
- "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
- "Placing on CPU",
+ @staticmethod
+ def constraint_quant_scale_inf(op):
+ "The IFM quantization scale divided by the OFM quantization scale must not be infinite"
+ ifm_scale = op.ifm.quantization.scale_f32
+ ofm_scale = op.ofm.quantization.scale_f32
+ valid = not np.isinf(ifm_scale / ofm_scale)
+ return valid, f"Op has infinite quantization scale. ifm_scale={ifm_scale} ofm_scale={ofm_scale}"
+
+ @staticmethod
+ def constraint_depth_multiplier(op):
+ "For depth multipliers > 1, IFM channels must be 1 and OFM channels must be equal to the depth multiplier"
+ depth_multiplier = op.attrs.get("depth_multiplier", 1)
+ if depth_multiplier > 1:
+ ifm_channels = op.ifm.shape[3]
+ ofm_channels = op.ofm.shape[3]
+ valid = (ifm_channels == 1) and (ofm_channels == depth_multiplier)
+ extra = (
+ f"Op has ifm_channels={ifm_channels}, ofm_channels={ofm_channels}"
+ f" and depth_multiplier={depth_multiplier}"
)
- return False
- return True
+ return valid, extra
+ return True, "Op has depth_multiplier=1"
- @classmethod
- def check_transpose_convolution_restrictions(cls, op):
- # check stride
- stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
- if stride_h != 2 or stride_w != 2:
- print("Warning: stride must be equal to 2, placing on CPU")
- return False
+ @staticmethod
+ def constraint_tconv_stride(op):
+ "Stride values for both width and height must be 2"
+ w = op.kernel.stride.x
+ h = op.kernel.stride.y
+ valid = (w == 2) and (h == 2)
+ return valid, f"Op has stride WxH as: {w}x{h}"
- # check output dimensions
- ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
- ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
- ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
+ @staticmethod
+ def constraint_tconv_same(op):
+ "SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride"
if op.attrs["padding"] == b"SAME":
- if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
- print(
- "Warning: for",
- op.type,
- "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
- "Placing on CPU",
- )
- return False
- elif op.attrs["padding"] == b"VALID":
- kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
- if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
- ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
- ):
- print(
- "Warning: for",
- op.type,
- "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
- "minus difference between kernel size and stride. Placing on CPU",
- )
- return False
- return True
+ w = op.kernel.stride.x
+ h = op.kernel.stride.y
+ ifm_shape = op.ifm.shape
+ ofm_shape = op.ofm.shape
+ valid = (ofm_shape[1] == (ifm_shape[1] * h)) and (ofm_shape[2] == (ifm_shape[2] * w))
+ return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and stride WxH as {w}x{h}"
+ return True, "Op has padding=VALID"
- @classmethod
- def check_pooling_restrictions(cls, op):
- # check stride
- stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
- if not is_integer(stride_w) or not is_integer(stride_h):
- print("Warning:", op.type, "has non-integer stride, placing on CPU")
- return False
- if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
- print(
- "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
- op.type, stride_w, stride_h
- )
+ @staticmethod
+ def constraint_tconv_valid(op):
+ """VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
+ minus difference between kernel size and stride"""
+ if op.attrs["padding"] == b"VALID":
+ s_w = op.kernel.stride.x
+ s_h = op.kernel.stride.y
+ k_w = op.kernel.width
+ k_h = op.kernel.height
+ ifm_shape = op.ifm.shape
+ ofm_shape = op.ofm.shape
+ height_check = ofm_shape[1] == (ifm_shape[1] * s_h + max(k_h - s_h, 0))
+ width_check = ofm_shape[2] == (ifm_shape[2] * s_w + max(k_w - s_w, 0))
+ valid = height_check and width_check
+ extra = (
+ f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape},"
+ f" stride WxH as {s_w}x{s_h} and kernel WxH as {k_w}x{k_h}"
)
- return False
+ return valid, extra
+ return True, "Op has padding=SAME"
- # check data type
- ifm_tensor, ofm_tensor = op.get_ifm_ofm()
- if ifm_tensor.dtype != ofm_tensor.dtype:
- if op.type != Op.ReduceSum:
- print("Warning: input data type doesn't match output data type, placing on CPU")
- return False
- # TODO: else check ReduceSum restrictions.
-
- # check batch size
- if ifm_tensor.shape[0] != 1:
- print("Warning: input batch size must be 1, placing on CPU")
- return False
+ @staticmethod
+ def constraint_matching_in_out_types(op):
+ "IFM and OFM data types must match"
+ ifm_dtype = op.ifm.dtype
+ ofm_dtype = op.ofm.dtype
+ valid = ifm_dtype == ofm_dtype
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
- # check kernel size
- kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
- if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
- if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
- print(
- "Warning:",
- op.type,
- "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
- kernel_w, kernel_h
- ),
- )
- return False
- if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops:
- if not 1 <= kernel_w * kernel_h <= 256 * 256:
- print(
- "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
- 256 * 256
- ),
- "placing on CPU",
- )
- return False
- if not 1 <= kernel_h <= 256:
- print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
- return False
+ @staticmethod
+ def constraint_filter_type(op):
+ "Kernel filter values for both width and height must be integer types"
+ w = op.kernel.width
+ h = op.kernel.height
+ valid = is_integer(w) and is_integer(h)
+ return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"
- return True
+ @classmethod
+ @docstring_format_args(filter_range)
+ def constraint_filter_range(cls, op):
+ "Kernel filter values for both width and height must be in the range [{}, {}]"
+ if op.attrs["padding"] == b"SAME":
+ w = op.kernel.width
+ h = op.kernel.height
+ filter_min, filter_max = cls.filter_range
+ valid = (filter_min <= w <= filter_max) and (filter_min <= h <= filter_max)
+ return valid, f"Op has kernel filter WxH as: {w}x{h}"
+ return True, "Op has padding=VALID"
@classmethod
- def check_resize_restrictions(cls, op):
- # check unsupported upscaling factor
- if op.type == Op.ResizeBilinear:
- if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
- return True
- if op.inputs[0].shape == op.outputs[0].shape:
- return True
- upscaled_shape = np.array(op.inputs[0].shape[1:3])
- out_shape = np.array(op.outputs[0].shape[1:3])
- while (upscaled_shape < out_shape).all():
- upscaled_shape *= 2
- if op.attrs["align_corners"]:
- upscaled_shape -= 1
- if np.array_equal(out_shape, upscaled_shape):
- return True
- return False
+ @docstring_format_args(filter_height_range)
+ def constraint_filter_height_range(cls, op):
+ "Kernel filter height must be in the range [{}, {}]"
+ h = op.kernel.height
+ filter_height_min, filter_height_max = cls.filter_height_range
+ valid = filter_height_min <= h <= filter_height_max
+ return valid, f"Op has kernel filter height as: {h}"
@classmethod
- def check_vector_product_restrictions(cls, op):
- # check data type
- ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
- if weight_tensor.element_size() > 1:
- print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
- return False
+ @docstring_format_args(filter_product_range)
+ def constraint_filter_product_range(cls, op):
+ "Product of kernel filter width and height must be in the range [{}, {}]"
+ product = op.kernel.elements_wh()
+ filter_product_min, filter_product_max = cls.filter_product_range
+ valid = filter_product_min <= product <= filter_product_max
+ return valid, f"Op has product of kernel filter width and height as: {product}"
- if not cls.check_bias_restrictions(bias_tensor):
- return False
+ @staticmethod
+ @docstring_format_args(filter_height_range)
+ def constraint_filter_height_range_valid_pad(op):
+ "VALID padding: Kernel filter height must be in the range [{}, {}]"
+ if op.attrs["padding"] == b"VALID":
+ return SupportedOperators.constraint_filter_height_range(op)
+ return True, "Op has padding=SAME"
- # check non const weights
- if weight_tensor.values is None:
- print("Warning:", op.type, "has non-const weights, placing on CPU")
- return False
+ @staticmethod
+ @docstring_format_args(filter_product_range)
+ def constraint_filter_product_range_valid_pad(op):
+ "VALID padding: Product of kernel filter width and height must be in the range [{}, {}]"
+ if op.attrs["padding"] == b"VALID":
+ return SupportedOperators.constraint_filter_product_range(op)
+ return True, "Op has padding=SAME"
- return True
+ @staticmethod
+ def constraint_resize(op):
+ """The width and height of the IFM and OFM must match one of the following criteria:
+ IFM W and H must both be 1
+ IFM must match OFM
+ OFM W and H must be 2x IFM -1, if align_corners is True
+ OFM W and H must be 2x IFM, if align_corners is False"""
+ # Easier to start with False condition as very few cases result in a supported resize
+ valid = False
+ ifm_shape = op.ifm.shape
+ ofm_shape = op.ofm.shape
+ align_corners = op.attrs.get("align_corners", False)
+ if len(ifm_shape) == 4:
+ # Valid if IFM W and H are both 1, or IFM and OFM shape are the same
+ if ((ifm_shape[1] == 1) and (ifm_shape[2] == 1)) or (ifm_shape == ofm_shape):
+ valid = True
+ else:
+ upscaled_shape = np.array(ifm_shape[1:3])
+ out_shape = np.array(ofm_shape[1:3])
+ while (upscaled_shape < out_shape).all():
+ upscaled_shape *= 2
+ if align_corners:
+ upscaled_shape -= 1
+ # Valid if OFM is 2x IFM (-1 for align corners)
+ if np.array_equal(out_shape, upscaled_shape):
+ valid = True
+ break
+ return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and align_corners={align_corners}"
- @classmethod
- def check_element_wise_restrictions(cls, op):
- # check data type
- ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
- # input and output datatype must match for these operators
- if (
- op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
- and ifm_tensor.dtype != ofm_tensor.dtype
- ):
- print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
- return False
- if op.type in cls.binary_elem_wise_add_mul_sub:
- # both inputs must have same type
- if ifm_tensor.dtype != ifm2_tensor.dtype:
- print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
- return False
- # signed input check
- if ifm_tensor.dtype.type & BaseType.Signed:
- # output must be signed
- if ofm_tensor.dtype.type & BaseType.Unsigned:
- print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
- return False
- # and 8, 16 or 32-bit
- bit_lengths = {8, 16, 32}
- if ofm_tensor.element_size() * 8 not in bit_lengths:
- print(
- "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
- )
- return False
- # unsigned input check, output must be same type or int32
- if ifm_tensor.dtype.type & BaseType.Unsigned and not (
- ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
- ):
- print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
- return False
- elif op.type in cls.binary_elem_wise_shift_ops:
- if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
- print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
- return False
- if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
- print("Warning:", op.type, "output datatype is not int32, placing on CPU")
- return False
+ @staticmethod
+ def constraint_matching_shapes(op):
+ "IFM and OFM shapes must match"
+ ifm_shape = op.ifm.shape
+ ofm_shape = op.ofm.shape
+ valid = ifm_shape == ofm_shape
+ return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"
- # check batch size
- if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
- print(
- "Warning:",
- op.type,
- "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
- )
- return False
- if op.type in cls.binary_elem_wise_main_ops: # if op type is unary, ifm2_tensor is None
- if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
- print(
- "Warning:",
- op.type,
- "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
- )
- return False
+ @staticmethod
+ def constraint_splitv_inferred(op):
+ "Only one size is allowed to be inferred"
+ sizes = op.ifm2.values
+ valid = np.count_nonzero(sizes == -1) <= 1
+ return valid, f"Op has multiple inferred sizes (-1): {sizes}"
- # negative alpha values are not supported
- if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
- print("Warning:", op.type, "has negative alpha, placing on CPU")
- return False
+ @staticmethod
+ def constraint_axis_exists(op):
+ "Axis attribute must exist"
+ axis = op.attrs.get("axis")
+ valid = axis is not None
+ return valid, f"Op has axis={axis}"
- # check if ifm or ifm2 has ofm shape
- if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
- print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
- return False
+ @staticmethod
+ def constraint_axis_valid(op):
+ "Axis attribute must be in the range [0, <ofm_dimensions>)"
+ dims = len(op.ofm.shape)
+ axis = op.attrs["axis"]
+ axis += dims if axis < 0 else 0
+ valid = 0 <= axis < dims
+ return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"
- if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
- return False
+ @staticmethod
+ def constraint_matching_dimensionality(op):
+ "All Input dimensionalities must match OFM dimensionality"
+ valid = True
+ extra = []
+ ofm_dim = len(op.ofm.shape)
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ dim = len(tens.shape)
+ if dim != ofm_dim:
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
+ extra = ", ".join(extra)
+ return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"
- return True
+ @staticmethod
+ def constraint_valid_dimensions(op):
+ "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
+ valid = True
+ extra = []
+ ofm_shape = op.ofm.shape
+ ofm_dim = len(ofm_shape)
+ axis = op.attrs["axis"]
+ axis += ofm_dim if axis < 0 else 0
+ tensors = [tens for tens in op.inputs if tens]
+ for tens in tensors:
+ if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
+ valid = False
+ extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+ extra = ", ".join(extra)
+ return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
- @classmethod
- def check_memory_only_restrictions(cls, op):
- if op.type == Op.StridedSlice:
- if len(op.inputs) != 4:
- warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
- return False
- input_tens, begin_tens, end_tens, strides_tens = op.inputs
- if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
- warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
- return False
- if not (
- len(input_tens.shape)
- == len(op.outputs[0].shape)
- == len(begin_tens.values)
- == len(end_tens.values)
- == len(strides_tens.values)
- ):
- warn_cpu(op, "has input tensors with shapes that are not supported")
- return False
- # check stride size
- if any(stride != 1 for stride in strides_tens.values):
- warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
- return False
- # check ellipsis_mask
- if op.attrs["ellipsis_mask"] != 0:
- warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
- return False
- # check if both new_axis_mask and shrink_axis_mask have bit set
- if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
- warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
- return False
- # Calculate offset start/end
- offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
- offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
- # check "end - begin" doesn't result in any zero or negative elements
- if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
- warn_cpu(
- op,
- "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
- begin_tens.values, end_tens.values
- ),
- )
- return False
- if op.type == Op.SplitV:
- # check that maximum one size is set to -1, indicating that size should be inferred
- sizes = op.inputs[1].values
- num_to_be_inferred = 0
- for size in sizes:
- if size == -1:
- num_to_be_inferred += 1
-
- if num_to_be_inferred > 1:
- print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
- return False
- if op.type in set((Op.Concat, Op.ConcatTFLite,)):
- axis = op.attrs.get("axis", None)
- if axis is None:
- print("Warning:", op.type, "invalid or missing axis, placing on CPU")
- return False
- if axis < 0:
- axis += len(op.inputs[0].shape)
- if not 0 <= axis < len(op.inputs[0].shape):
- print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
- return False
- ofm = op.outputs[0]
- ofm_dims = len(ofm.shape)
- for ifm in op.inputs:
- if len(ifm.shape) != ofm_dims:
- return False
- for i in range(ofm_dims):
- if i != axis and ifm.shape[i] != ofm.shape[i]:
- print(
- "Warning:",
- op.type,
- "invalid ifm:",
- ifm.name,
- ifm.shape,
- "mismatch in dimension",
- i,
- ", placing on CPU",
- )
- return False
+ @staticmethod
+ def constraint_stridedslice_input_count(op):
+ "Exactly 4 Input tensors are required"
+ inputs = len(op.inputs)
+ valid = inputs == 4
+ return valid, f"Op has {inputs} inputs"
- return True
+ @staticmethod
+ def constraint_stridedslice_inputs_const(op):
+ "Begin, End and Stride Input tensors must be constant"
+ valid = True
+ extra = []
+ _, begin, end, strides = op.inputs
+ if begin.values is None:
+ valid = False
+ extra.append(f"Begin tensor '{begin.name}'")
+ if end.values is None:
+ valid = False
+ extra.append(f"End tensor '{end.name}'")
+ if strides.values is None:
+ valid = False
+ extra.append(f"Stride tensor '{strides.name}'")
+ extra = ", ".join(extra)
+ return valid, f"Op has non-constant tensors: {extra}"
- @classmethod
- def check_quantization_restrictions_binary_elem_wise(cls, op):
- # checks that IFM1, IFM2 and OFM quantization are equal for binary ops
+ @staticmethod
+ def constraint_stridedslice_tens_size_matches(op):
+ "All Input sizes must match OFM size"
+ ifm, begin, end, strides = op.inputs
+ ifm_size = len(ifm.shape)
+ ofm_size = len(op.ofm.shape)
+ begin_size = len(begin.values)
+ end_size = len(end.values)
+ strides_size = len(strides.values)
+ valid = ifm_size == ofm_size == begin_size == end_size == strides_size
+ extra = (
+ f"Op has ofm_size={ofm_size}, ifm_size={ifm_size},"
+ f" begin_size={begin_size}, end_size={end_size} and strides_size={strides_size}"
+ )
+ return valid, extra
- assert len(op.inputs) >= 2 and len(op.outputs) == 1
+ @staticmethod
+ def constraint_stridedslice_stride_values(op):
+ "All Strides values must be 1"
+ strides = op.inputs[3]
+ valid = all(stride == 1 for stride in strides.values)
+ return valid, f"Op has strides values {strides.values}"
- if (
- not check_tens_quantized(op.inputs[0])
- or not check_tens_quantized(op.inputs[1])
- or not check_tens_quantized(op.outputs[0])
- ):
- warn_cpu(op, "has non-quantised input and/or output tensors")
- return False
+ @staticmethod
+ def constraint_ellipsis_mask(op):
+ "ellipsis_mask must be 0"
+ ellipsis = op.attrs["ellipsis_mask"]
+ valid = ellipsis == 0
+ return valid, f"Op has ellipsis mask as: {ellipsis}"
- if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal(
- op.inputs[0], op.outputs[0]
- ):
- warn_cpu(op, "has input/output tensors with different quantisation which is illegal")
- return False
+ @staticmethod
+ def constraint_axis_masks(op):
+ "new_axis_mask and shrink_axis_mask cannot both be set"
+ new_axis = op.attrs["new_axis_mask"]
+ shrink_axis = op.attrs["shrink_axis_mask"]
+ valid = (new_axis == 0) or (shrink_axis == 0)
+ return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"
- return True
+ @staticmethod
+ def constraint_slice_ranges(op):
+ "Slice 'end' values must be greater than 'begin' values"
+ ifm, begin, end, _ = op.inputs
+ # Calculate offset begin/end
+ offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
+ offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
+ # Check "end - begin" doesn't result in any zero or negative elements
+ valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
+ return valid, f"Op has begin_values={begin.values} and end_values={end.values}"
- @classmethod
- def check_activation_ops(cls, op):
- if op.type == Op.Softmax:
- ifm_tensor = op.inputs[0]
- ofm_tensor = op.outputs[0]
-
- # check data type
- if ifm_tensor.dtype != ofm_tensor.dtype:
- print("Warning:", op.type, "input type differs from output type, placing on CPU")
- return False
+ @staticmethod
+ def constraint_matching_inputs_types(op):
+ "Both Input data types must match"
+ ifm_dtype = op.ifm.dtype
+ ifm2_dtype = op.ifm2.dtype
+ valid = ifm_dtype == ifm2_dtype
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"
- if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
- print(
- "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
- )
- return False
+ @staticmethod
+ def constraint_matching_signed(op):
+ "For IFM that are signed, OFM must also be signed"
+ valid = True
+ ifm_dtype = op.ifm.dtype
+ ofm_dtype = op.ofm.dtype
+ if ifm_dtype.type & BaseType.Signed:
+ valid = bool(ofm_dtype.type & BaseType.Signed)
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
- # check shape
- if ifm_tensor.shape != ofm_tensor.shape:
- print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
- return False
+ @staticmethod
+ def constraint_unsigned_valid(op):
+ "For IFM that are unsigned, OFM must either be the same type or int32"
+ valid = True
+ ifm_dtype = op.ifm.dtype
+ ofm_dtype = op.ofm.dtype
+ if ifm_dtype.type & BaseType.Unsigned:
+ valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
- elif op.type.is_relu_op():
- ifm_tensor, ofm_tensor = op.get_ifm_ofm()
- if np.isinf(ifm_tensor.quantization.scale_f32 / ofm_tensor.quantization.scale_f32):
- print("Warning:", op.type, "has an infinite scale value, placing on CPU")
- return False
+ @staticmethod
+ def constraint_inputs_int32(op):
+ "Both Input data types must be int32"
+ ifm_dtype = op.ifm.dtype
+ ifm2_dtype = op.ifm2.dtype
+ valid = (ifm_dtype == DataType.int32) and (ifm2_dtype == DataType.int32)
+ return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"
- return True
+ @staticmethod
+ def constraint_output_int32(op):
+ "OFM must be int32"
+ ofm_dtype = op.ofm.dtype
+ valid = ofm_dtype == DataType.int32
+ return valid, f"Op has ofm_dtype={ofm_dtype}"
- @classmethod
- def check_bias_restrictions(cls, bias_tensor):
- # check data type
- if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
- print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
- return False
+ @staticmethod
+ def constraint_matching_quantization_parameters(op):
+ "Both Input quantization parameters must match OFM quantization parameters"
+ valid = True
+ extra = []
+ if not check_quantized_tens_scaling_equal(op.ofm, op.ifm):
+ valid = False
+ extra.append(op.ifm.name)
+ if not check_quantized_tens_scaling_equal(op.ofm, op.ifm2):
+ valid = False
+ extra.append(op.ifm2.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has tensors with different quantization parameters to the OFM '{op.ofm.name}': {extra}"
- # check if values fits in 40-bit
- if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
- for quant_value in bias_tensor.quant_values:
- if not (-(1 << 39) <= quant_value < (1 << 39)):
- print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
- return False
+ @staticmethod
+ def constraint_elemwise_batch_size(op):
+ "Batch size must be 1 for Input tensors with more than 2 dimensions"
+ valid = True
+ extra = []
+ for tens in (op.ifm, op.ifm2):
+ # Unary ops have ifm2 as None
+ if tens is not None:
+ if (len(tens.shape) > 2) and (tens.shape[0] != 1):
+ valid = False
+ extra.append(tens.name)
+ extra = ", ".join(extra)
+ return valid, f"Op has invalid input tensors: {extra}"
- return True
+ @staticmethod
+ def constraint_matching_either_shapes(op):
+ "At least one Input's shape must match the OFM's shape"
+ ifm_shape = op.ifm.shape
+ ifm2_shape = op.ifm2.shape if op.ifm2 else None
+ ofm_shape = op.ofm.shape
+ valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
+ return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"
+
+ @staticmethod
+ def constraint_alpha_valid(op):
+ "Alpha must not be negative"
+ alpha = op.attrs["alpha"]
+ valid = alpha >= 0
+ return valid, f"Op has alpha={alpha}"
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 665ebc2..595ea59 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -29,67 +29,9 @@ from ethosu.vela.test import testutil
support = SupportedOperators()
-def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
- qp = QuantizationParameters()
- in0 = Tensor(in_shape, DataType.uint8, "in")
- in0.quantization = qp
- in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
- in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
- in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
- out = Tensor(out_shape, DataType.uint8, "out")
- out.quantization = qp
- attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
- return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
-
-
-def create_strided_slice():
- # Creates a valid strided slice operator with some valid inputs/outputs
- op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
- op.attrs["begin_mask"] = 1
- op.attrs["end_mask"] = 9
- assert support.is_operator_supported(op)
- return op
-
-
-def test_strided_slice():
- # Tests support for StridedSlice operator
- op = create_strided_slice()
- # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
- op.attrs["new_axis_mask"] = 2
- assert support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["shrink_axis_mask"] = 3
- assert support.is_operator_supported(op)
- # But setting both to non-zero is not supported
- op.attrs["new_axis_mask"] = 2
- assert not support.is_operator_supported(op)
- # begin values must not be None
- op.inputs[1].values = None
- assert not support.is_operator_supported(op)
- # Unsupported strides
- op = create_strided_slice()
- op.inputs[3].values = [1, 1, 2, 1]
- assert not support.is_operator_supported(op)
- # Wrong number of input tensors
- op = create_strided_slice()
- op.add_input_tensor(op.inputs[0].clone())
- assert not support.is_operator_supported(op)
- # Unsupported ellipsis mask
- op = create_strided_slice()
- op.attrs["ellipsis_mask"] = 1
- assert not support.is_operator_supported(op)
- # Examples where end offset <= begin offset
- op = create_strided_slice()
- op.inputs[1].values = [0, 7, 2, 0]
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.inputs[2].values = [0, 7, 2, 0]
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["begin_mask"] = 0
- assert not support.is_operator_supported(op)
- op = create_strided_slice()
- op.attrs["end_mask"] = 0
+def test_constraint_tens_no_dynamic():
+    # Tensors cannot be dynamic (no shape and not a scalar)
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
assert not support.is_operator_supported(op)
@@ -99,18 +41,20 @@ def test_constraint_tens_defined_shape():
assert not support.is_operator_supported(op)
-def test_constraint_tens_output_shapeless():
- # Shapeless output is not allowed at all:
+def test_constraint_tens_output_scalar():
+ # Scalar output is not allowed at all:
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
+ op.ofm.values = 0.5
assert not support.is_operator_supported(op)
-def test_constraint_tens_input_shapeless():
+def test_constraint_tens_input_scalar():
     # Shapeless input is allowed if it's of a certain type:
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
assert support.is_operator_supported(op)
# Invalid shapeless input due to op type:
op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
+ op.ifm.values = 0.5
assert not support.is_operator_supported(op)
@@ -149,6 +93,7 @@ def test_constraint_tens_quant_none_check():
def test_constraint_tens_quant_scale():
     # Quantization scale cannot be infinite
qp = QuantizationParameters()
+ qp.zero_point = 0
qp.scale_f32 = np.inf
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
assert not support.is_operator_supported(op)
@@ -219,12 +164,12 @@ def test_constraint_weights_type():
assert not support.is_operator_supported(op)
-def test_constraint_weights_nonconst():
+def test_constraint_weights_const():
     # Weight tensors cannot be non-const
op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 1, "stride_h": 1}
weights = Tensor([64, 64, 1, 1], DataType.uint8, "weights")
- weights.quantization = QuantizationParameters()
+ weights.quantization = testutil.default_quant_params()
op.add_input_tensor(weights)
assert not support.is_operator_supported(op)
@@ -251,7 +196,7 @@ def test_constraint_bias_40bit():
op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
bias = Tensor([1, 1, 1, 1], DataType.int64, "bias")
- bias.quant_values = np.array([0x1FF_FFFF_FFFF])
+ bias.quant_values = np.array([0x01FF_FFFF_FFFF])
op.add_input_tensor(bias)
assert not support.is_operator_supported(op)
@@ -260,3 +205,452 @@ def test_constraint_batch_size():
op = testutil.create_op_with_quant_tensors(Op.Conv2D, [2, 8, 8, 8], [1, 8, 8, 8], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert not support.is_operator_supported(op)
+
+
+def test_constraint_quant_scale_inf():
+ op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ofm.quantization.scale_f32 = np.float32(1e-39)
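+    # 1.0 / np.float32(1e-39) overflows float32 (max ~3.4e38), so the effective scale becomes infinite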
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_depth_multiplier():
+ # Valid. Depth multiplier is 1 so no further constraints
+ op = testutil.create_op_with_quant_tensors(
+ Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
+ )
+ op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 1}
+ assert support.is_operator_supported(op)
+    # Invalid. Depth multiplier doesn't equal the OFM channel count
+ op = testutil.create_op_with_quant_tensors(
+ Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
+ )
+ op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
+ assert not support.is_operator_supported(op)
+    # Valid. Depth multiplier equals the OFM channel count
+ op = testutil.create_op_with_quant_tensors(
+ Op.DepthwiseConv2DBias, [1, 1, 1, 1], [1, 1, 1, 2], weights_shape=[1, 1, 1, 1]
+ )
+ op.attrs = {"stride_w": 1, "stride_h": 1, "depth_multiplier": 2}
+ assert support.is_operator_supported(op)
+
+
+def test_constraint_tconv_stride():
+ # Strides must be 2
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 1, "stride_h": 1, "padding": b"SAME"}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tconv_same():
+ # Valid
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"SAME"}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert support.is_operator_supported(op)
+ # Invalid
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[1, 1, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"SAME"}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_tconv_valid():
+ # Valid
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[4, 4, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"VALID"}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert support.is_operator_supported(op)
+ # Invalid
+ op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[2, 2, 1, 1])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"VALID"}
+ ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
+ ifm.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_matching_in_out_types():
+ # Valid
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": b"SAME"}
+ assert support.is_operator_supported(op)
+    # Invalid. Data types for ifm and ofm must match (default uint8)
+ op.ifm.dtype = DataType.int8
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_filter_type():
+ # Filter width/height must be integers
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": b"SAME"}
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_filter_range():
+ # Avg pool restrictions are dependent on padding:
+ # SAME padding restricts both W and H to max 8
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 20, "filter_height": 20, "padding": b"SAME"}
+ assert not support.is_operator_supported(op)
+ # VALID padding limits are much larger
+ op.attrs["padding"] = b"VALID"
+ assert support.is_operator_supported(op)
+
+
+def test_constraint_filter_height_range_valid_pad():
+ # Avg pool restrictions are dependent on padding:
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": b"VALID"}
+ assert support.is_operator_supported(op)
+ # VALID padding restricts to 256 in filter height
+ op.attrs["filter_height"] = 257
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_filter_product_height_range_valid_pad():
+ # Avg pool restrictions are dependent on padding:
+ op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": b"VALID"}
+ assert support.is_operator_supported(op)
+ # VALID padding restricts filter W x H to 256x256
+ op.attrs["filter_width"] = 257
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_filter_height_range():
+    # Max pool restrictions aren't dependent on padding
+ op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": b"SAME"}
+ assert support.is_operator_supported(op)
+ # Restricts to 256 in filter height
+ op.attrs["filter_height"] = 257
+ assert not support.is_operator_supported(op)
+    # Doesn't matter if SAME or VALID
+ op.attrs["padding"] = b"VALID"
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_filter_product_height_range():
+    # Max pool restrictions aren't dependent on padding
+ op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": b"SAME"}
+ assert support.is_operator_supported(op)
+ # Restricts filter W x H to 256x256
+ op.attrs["filter_width"] = 257
+ assert not support.is_operator_supported(op)
+    # Doesn't matter if SAME or VALID
+ op.attrs["padding"] = b"VALID"
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_resize():
+ # IFM W and H == 1
+ op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 1, 1, 8], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # IFM == OFM
+ op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 8, 8, 8], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # IFM x2 == OFM ; align_corners = False
+ op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # IFM x2 -1 == OFM ; align_corners = True
+ op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 7, 7, 8])
+ op.attrs["align_corners"] = True
+ assert support.is_operator_supported(op)
+ # Invalid cases
+ op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 20, 20, 8])
+ assert not support.is_operator_supported(op)
+ op.attrs["align_corners"] = True
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_matching_shapes():
+ # Softmax requires the ifm and ofm shapes to match
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
+ assert support.is_operator_supported(op)
+
+
+def test_constraint_splitv_inferred():
+ # SplitV requires a maximum of one inferred shape (-1)
+ qp = testutil.default_quant_params()
+ op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
+ op.add_input_tensor(sizes)
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
+ op.add_input_tensor(sizes)
+ assert support.is_operator_supported(op)
+
+
+def test_constraint_concat_pass():
+ # A working concat
+ op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert support.is_operator_supported(op)
+
+
+def test_constraint_axis_exists():
+ # Missing axis attribute
+ op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_axis_valid():
+ # Invalid axis attribute
+ op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 7
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_matching_dimensionality():
+ # Mismatching dimensionality: 4D+2D=4D
+ op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_valid_dimensions():
+ # Mismatching dimension value:
+    # ifm2 has W and H of 2, which is not the concat axis and doesn't match ifm1 or the ofm
+ op = testutil.create_op_with_quant_tensors(Op.Concat, [1, 1, 1, 4], [1, 1, 1, 8])
+ ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
+ ifm2.quantization = testutil.default_quant_params()
+ op.add_input_tensor(ifm2)
+ op.attrs["axis"] = 3
+ assert not support.is_operator_supported(op)
+
+
+def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, DataType.uint8, "in")
+ in0.quantization = qp
+ in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
+ in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
+ in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
+ out = Tensor(out_shape, DataType.uint8, "out")
+ out.quantization = qp
+ attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
+ return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
+
+
+def create_strided_slice():
+    # Creates a strided slice operator with valid inputs/outputs
+ op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
+ op.attrs["begin_mask"] = 1
+ op.attrs["end_mask"] = 9
+ assert support.is_operator_supported(op)
+ return op
+
+
+def test_constraint_stridedslice_input_count():
+ # Wrong number of input tensors
+ op = create_strided_slice()
+ op.add_input_tensor(op.inputs[0].clone())
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_stridedslice_inputs_const():
+ # begin, end, stride values must not be None
+ op = create_strided_slice()
+ op.inputs[1].values = None
+ assert not support.is_operator_supported(op)
+ op = create_strided_slice()
+ op.inputs[2].values = None
+ assert not support.is_operator_supported(op)
+ op = create_strided_slice()
+ op.inputs[3].values = None
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_stridedslice_tens_size_matches():
+ op = create_strided_slice()
+ op.inputs[1].values = [1, 1, 1, 1, 1, 1, 1, 1]
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_stridedslice_stride_values():
+ # Unsupported strides
+ op = create_strided_slice()
+ op.inputs[3].values = [1, 1, 2, 1]
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_ellipsis_mask():
+ # Unsupported ellipsis mask
+ op = create_strided_slice()
+ op.attrs["ellipsis_mask"] = 1
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_axis_masks():
+ op = create_strided_slice()
+ # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
+ op.attrs["new_axis_mask"] = 2
+ assert support.is_operator_supported(op)
+ op = create_strided_slice()
+ op.attrs["shrink_axis_mask"] = 3
+ assert support.is_operator_supported(op)
+ # But setting both to non-zero is not supported
+ op.attrs["new_axis_mask"] = 2
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_slice_ranges():
+ # Examples where end offset <= begin offset
+ op = create_strided_slice()
+ op.inputs[1].values = [0, 7, 2, 0]
+ assert not support.is_operator_supported(op)
+ op = create_strided_slice()
+ op.inputs[2].values = [0, 7, 2, 0]
+ assert not support.is_operator_supported(op)
+ op = create_strided_slice()
+ op.attrs["begin_mask"] = 0
+ assert not support.is_operator_supported(op)
+ op = create_strided_slice()
+ op.attrs["end_mask"] = 0
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_matching_inputs_types():
+ # input data types must match (default is uint8)
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ifm2.dtype = DataType.int8
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_matching_signed():
+ # signed inputs require output to also be signed
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ op.ofm.dtype = DataType.uint8
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_unsigned_valid():
+ # unsigned inputs require output to be either:
+ op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ # the same (default uint8)
+ assert support.is_operator_supported(op)
+ op.ofm.dtype = DataType.int8
+ assert not support.is_operator_supported(op)
+ op.ofm.dtype = DataType.int16
+ assert not support.is_operator_supported(op)
+ # or int32
+ op.ofm.dtype = DataType.int32
+ assert support.is_operator_supported(op)
+
+
+def test_constraint_inputs_int32():
+ # both inputs must be type int32
+ op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ assert not support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+ assert support.is_operator_supported(op)
+ op.ifm2.dtype = DataType.int16
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_output_int32():
+ # output must be type int32
+ op = testutil.create_elemwise_op(Op.SHL, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+ assert support.is_operator_supported(op)
+ op.ofm.dtype = DataType.int16
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_matching_quantization_parameters():
+ qp = QuantizationParameters()
+ qp.scale_f32 = np.float32(1.5)
+ qp.zero_point = 128
+ # valid - all matching (uses default quant params)
+ op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # invalid - ifm mismatch ofm
+ op.ifm.quantization = qp
+ assert not support.is_operator_supported(op)
+ # invalid - ifm2 mismatch ofm
+ op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ifm2.quantization = qp
+ assert not support.is_operator_supported(op)
+ # invalid - both ifm and ifm2 mismatch ofm
+ op = testutil.create_elemwise_op(Op.Minimum, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
+ op.ifm.quantization = qp
+ op.ifm2.quantization = qp
+ assert not support.is_operator_supported(op)
+ # valid - all matching
+ op.ofm.quantization = qp
+ assert support.is_operator_supported(op)
+
+
+def test_constraint_elemwise_batch_size():
+ # BINARY CASE
+ # Batch can be >1 if dims is <=2D
+ op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [2, 2], [2, 2])
+ assert support.is_operator_supported(op)
+ # For dims >2D, batch must be 1
+ op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2], [1, 2, 2], [1, 2, 2])
+ assert support.is_operator_supported(op)
+ # invalid case
+ op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
+ assert not support.is_operator_supported(op)
+
+ # UNARY CASE
+ # Batch can be >1 if dims is <=2D
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
+ assert support.is_operator_supported(op)
+ # For dims >2D, batch must be 1
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2], None, [1, 2, 2], datatype=DataType.int32)
+ assert support.is_operator_supported(op)
+ # invalid case
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_matching_either_shapes():
+ # BINARY CASE
+ # At least one ifm shape must match ofm's shape
+ op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [4, 4], [2, 2])
+ assert support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [2, 2], [2, 2])
+ assert support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
+ assert not support.is_operator_supported(op)
+
+ # UNARY CASE
+    # No second input, so this is treated the same as requiring the ifm shape to match the ofm shape
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
+ assert support.is_operator_supported(op)
+ op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_alpha_valid():
+ # Alpha cannot be negative
+ op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
+ op.attrs["alpha"] = 0
+ assert support.is_operator_supported(op)
+ op.attrs["alpha"] = -1
+ assert not support.is_operator_supported(op)
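
These are plain assert-style unit tests; assuming pytest is the runner used
for the vela test suite, the refactored checks can be exercised directly:

    pytest ethosu/vela/test/test_supported_operators.py -q
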
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 92bf53d..b06008a 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -39,16 +39,23 @@ def create_arch():
)
+def default_quant_params():
+ qp = QuantizationParameters()
+ qp.scale_f32 = np.float32(1)
+ qp.zero_point = 0
+ return qp
+
+
def create_elemwise_op(
- type,
+ op_type,
name,
ifm_shape,
ifm2_shape,
ofm_shape,
datatype=DataType.uint8,
- ifm_quant=QuantizationParameters(),
- ifm2_quant=QuantizationParameters(),
- ofm_quant=QuantizationParameters(),
+ ifm_quant=default_quant_params(),
+ ifm2_quant=default_quant_params(),
+ ofm_quant=default_quant_params(),
):
# Creates elementwise operation with constant IFM/IFM2
if datatype.size_in_bytes() == 1:
@@ -57,15 +64,16 @@ def create_elemwise_op(
np_type = np.int16
else:
np_type = np.int32
- op = Operation(type, name)
+ op = Operation(op_type, name)
op.add_input_tensor(
create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type, quantization=ifm_quant)
)
- op.add_input_tensor(
- create_const_tensor(
- name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type, quantization=ifm2_quant
+ if ifm2_shape is not None:
+ op.add_input_tensor(
+ create_const_tensor(
+ name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type, quantization=ifm2_quant
+ )
)
- )
ofm = Tensor(ofm_shape, datatype, name + "_ofm")
ofm.quantization = ofm_quant
op.set_output_tensor(ofm)
@@ -73,11 +81,10 @@ def create_elemwise_op(
def create_op_with_quant_tensors(op_type, ifm_shape, ofm_shape, weights_shape=None, datatype=DataType.uint8):
- qp = QuantizationParameters()
ifm = Tensor(ifm_shape, datatype, "in")
- ifm.quantization = qp
+ ifm.quantization = default_quant_params()
ofm = Tensor(ofm_shape, datatype, "out")
- ofm.quantization = qp
+ ofm.quantization = default_quant_params()
op = Operation(op_type, "op")
op.add_input_tensor(ifm)
op.set_output_tensor(ofm)
@@ -89,6 +96,7 @@ def create_op_with_quant_tensors(op_type, ifm_shape, ofm_shape, weights_shape=No
np_type = np.int16
else:
np_type = np.int32
+ qp = default_quant_params()
qp.zero_point = np.zeros(weights_shape)
weights = create_const_tensor(
"weights", weights_shape, datatype, np.zeros(weights_shape), np_type, quantization=qp