path: root/ethosu/vela/supported_operators.py
author     Michael McGeagh <michael.mcgeagh@arm.com>   2020-10-14 09:30:02 +0100
committer  Michael McGeagh <michael.mcgeagh@arm.com>   2020-10-19 12:06:01 +0100
commit     1f951fc47abd52db0ac048802dab0c95b149d7b8 (patch)
tree       2d4337a75f557813bdac3a9c5a1272a7bbc792b6 /ethosu/vela/supported_operators.py
parent     c6ac1944d9934faf6d22825cdd3273afe55432a4 (diff)
download   ethos-u-vela-1f951fc47abd52db0ac048802dab0c95b149d7b8.tar.gz
MLBEDSW-2412 Refactor constraints for conv ops
Using a new system to report constraints, the existing functionality for checking conv-like ops has been replaced. The new system allows all constraints to be reported regardless of any input network.

Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: If81177deca2a3b57c9dd9a3a08868cbc9cef0c23
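The core of the new system is that each constraint is a small function whose docstring doubles as the human-readable rule, with numeric limits baked in through the docstring_format_args decorator, and which returns a (valid, extra) pair. The sketch below is a minimal, self-contained illustration of that pattern, not Vela's actual code: the local docstring_format_args stand-in, the STRIDE_RANGE constant and the plain-string op types are assumptions made so the snippet runs on its own.

from collections import defaultdict

# Stand-in for the docstring_format_args decorator used in the patch
# (an assumption about its behaviour): it formats the numeric limits into
# the constraint's docstring so the reported reason shows concrete values.
def docstring_format_args(args):
    def decorator(func):
        func.__doc__ = func.__doc__.format(*args)
        return func
    return decorator

STRIDE_RANGE = (1, 3)  # mirrors the stride_range class attribute in the diff

@docstring_format_args(STRIDE_RANGE)
def constraint_stride_range(op):
    "Stride values for both width and height must be in the range [{}, {}]"
    w, h = op.attrs["stride_w"], op.attrs["stride_h"]
    valid = STRIDE_RANGE[0] <= w <= STRIDE_RANGE[1] and STRIDE_RANGE[0] <= h <= STRIDE_RANGE[1]
    return valid, "Op has stride WxH as: {}x{}".format(w, h)

# Specific constraints are registered against a tuple of op types, so one
# list of checks covers every conv-like op at once (plain strings stand in
# for Vela's Op enum members here).
specific_constraints = defaultdict(list)
specific_constraints[("Conv2D", "DepthwiseConv2D", "TransposeConv2D")].append(constraint_stride_range)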
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--   ethosu/vela/supported_operators.py   243
1 file changed, 157 insertions, 86 deletions
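For reporting, every op goes through a single lookup: generic constraints always run, a specific list runs whenever the op type appears in the tuple it was registered under, and a failed check prints the constraint's docstring plus the extra detail string. The sketch below imitates that flow under stated assumptions; FakeOp with a bare ifm_shape attribute is a hypothetical stand-in for Vela's Operation and Tensor classes, and the registry keys are strings rather than Op enum members.

from collections import defaultdict

def constraint_batch_size(op):
    "IFM Tensor batch size must be 1"
    batch = op.ifm_shape[0]
    return batch == 1, "Tensor has batch size: {}".format(batch)

# Hypothetical registry; tensor shape/dtype/quantisation checks would sit in
# the generic list, op-family checks in the tuple-keyed specific dictionary.
generic_constraints = []
specific_constraints = defaultdict(list)
specific_constraints[("Conv2D", "DepthwiseConv2D")].append(constraint_batch_size)

class FakeOp:
    # Minimal stand-in for an operation, just enough to drive the checks.
    def __init__(self, op_type, name, ifm_shape):
        self.type, self.name, self.ifm_shape = op_type, name, ifm_shape

def get_constraints_list(op_type):
    # Generic checks always apply; specific checks apply when the op type
    # appears in the tuple key they were registered under.
    constraints = list(generic_constraints)
    for ops, checks in specific_constraints.items():
        if op_type in ops:
            constraints.extend(checks)
    return constraints

def is_operator_supported(op):
    for constraint in get_constraints_list(op.type):
        valid, extra = constraint(op)
        if not valid:
            print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
            print(" - {}".format(constraint.__doc__))
            if extra:
                print("   {}".format(extra))
            return False
    return True

print(is_operator_supported(FakeOp("Conv2D", "conv1", ifm_shape=[2, 8, 8, 3])))  # batch size 2 -> False

Because the rule text lives in the docstring, the full set of constraints can also be listed without running any network, which appears to be what the commit message means by reporting constraints regardless of any input network.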
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 66c74fce..f4dd5796 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
+from collections import defaultdict
+
import numpy as np
from .data_type import BaseType
@@ -43,6 +45,7 @@ class SupportedOperators:
convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
+ convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
max_pooling_ops = Op.op_set(Op.is_maxpool_op)
avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
@@ -51,12 +54,8 @@ class SupportedOperators:
mac_main_ops = (
# RNN/LSTM/GRU
set((Op.BlockLSTM,))
- # convolutions
- | convolution_ops
- # depth-wise convolutions
- | depthwise_convolution_ops
- # transpose convolutions
- | transpose_convolution_ops
+ # conv/depthwiseconv/transposeconv
+ | convolution_like_ops
# pooling
| pooling_ops
# resizing/upscaling
@@ -88,17 +87,21 @@ class SupportedOperators:
shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
- supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+ # Supported data types
+ supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+ supported_bias_dtypes = set((DataType.int32, DataType.int64))
# Defined ranges for allowed values:
tens_dim_range = (1, 65535)
+ stride_range = (1, 3)
+ dilation_range = (1, 2)
+ dilated_height_range = (1, 64)
+ dilated_product_range = (1, 64 * 64)
+ weights_limit = 127 * 65536
def __init__(self):
# Setup supported operator restriction checkers
self.supported_operator_restrictions = {}
self.supported_operator_restrictions.update(
- {op: self.check_convolution_restrictions for op in SupportedOperators.convolution_ops}
- )
- self.supported_operator_restrictions.update(
{op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
)
self.supported_operator_restrictions.update(
@@ -134,13 +137,37 @@ class SupportedOperators:
self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
self.generic_constraints.append(SupportedOperators.constraint_faf)
+ # Setup specific constraints. The key in the dictionary must be a tuple of op types the constraints apply to
+ self.specific_constraints = defaultdict(list)
+ # Conv-like ops have the same checks applied to them:
+ conv_like_ops = tuple(SupportedOperators.convolution_like_ops)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_type)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_range)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_type)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_range)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_height_range)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_product_range)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_type)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_nonconst)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_limit)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_type)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_40bit)
+ self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_batch_size)
+
+ def get_constraints_list(self, op_type):
+ constraint_list = list(self.generic_constraints)
+ for ops in self.specific_constraints:
+ if op_type in ops:
+ constraint_list.extend(self.specific_constraints[ops])
+ return constraint_list
def is_operator_supported(self, op):
if op.type not in SupportedOperators.supported_operators:
if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
print("Info: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
return False
- for constraint in self.generic_constraints:
+
+ for constraint in self.get_constraints_list(op.type):
valid, extra = constraint(op)
if not valid:
print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
@@ -148,6 +175,7 @@ class SupportedOperators:
if extra:
print(" {}".format(extra))
return False
+
if op.type in self.supported_operator_restrictions:
return self.supported_operator_restrictions[op.type](op)
return True
@@ -186,7 +214,7 @@ class SupportedOperators:
if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
valid = False
extra.append(tens.name)
- extra = "Op '{}' has shapeless input tensor(s): {}".format(op.name, ", ".join(extra))
+ extra = "Op has shapeless input tensor(s): {}".format(", ".join(extra))
return valid, extra
@staticmethod
@@ -202,15 +230,15 @@ class SupportedOperators:
return valid, ", ".join(extra)
@classmethod
- @docstring_format_args([supported_dtypes])
+ @docstring_format_args([supported_op_dtypes])
def constraint_tens_dtype(cls, op):
- "Tensors must be of type: {}"
+ "Input(s), Output and Weight Tensors must be of type: {}"
valid = True
extra = []
tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
tensors = tensors if tensors else op.inputs
for tens in tensors:
- if tens.dtype not in cls.supported_dtypes:
+ if tens.dtype not in cls.supported_op_dtypes:
valid = False
extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
return valid, ", ".join(extra)
@@ -227,13 +255,13 @@ class SupportedOperators:
if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
valid = False
extra.append(tens.name)
- extra = "Op '{}' has int32 tensor(s): {}".format(op.name, ", ".join(extra))
+ extra = "Op has int32 tensor(s): {}".format(", ".join(extra))
return valid, extra
@classmethod
@docstring_format_args(tens_dim_range)
def constraint_tens_dimension(cls, op):
- "Tensor dimensions must be in the range {}-{} (inclusive)"
+ "Tensor dimensions must be in the range [{}, {}]"
tens_min, tens_max = cls.tens_dim_range
valid = True
extra = []
@@ -275,85 +303,129 @@ class SupportedOperators:
"The fused activation function (if present) must be one of type: {}"
faf = op.activation
valid = (faf is None) or (faf in cls.supported_fused_activations)
- extra = "Op '{}' has its fused activation function as: {}".format(op.name, faf)
+ extra = "Op has its fused activation function as: {}".format(faf)
+ return valid, extra
+
+ @staticmethod
+ def constraint_stride_type(op):
+ "Stride values for both width and height must be integer types"
+ w = op.attrs["stride_w"]
+ h = op.attrs["stride_h"]
+ valid = is_integer(w) and is_integer(h)
+ extra = "Op has stride WxH as: {}x{}".format(repr(w), repr(h))
return valid, extra
@classmethod
- def check_convolution_restrictions(cls, op):
- # check stride
- stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
- if not is_integer(stride_w) or not is_integer(stride_h):
- print("Warning:", op.type, "has non-integer stride, placing on CPU")
- return False
- if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
- print(
- "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
- op.type, stride_w, stride_h
- )
- )
- return False
+ @docstring_format_args(stride_range)
+ def constraint_stride_range(cls, op):
+ "Stride values for both width and height must be in the range [{}, {}]"
+ w = op.attrs["stride_w"]
+ h = op.attrs["stride_h"]
+ stride_min, stride_max = cls.stride_range
+ valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
+ extra = "Op has stride WxH as: {}x{}".format(w, h)
+ return valid, extra
- # check dilation
- dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
- dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
- if not is_integer(dilation_w_factor) or not is_integer(dilation_h_factor):
- print("Warning:", op.type, "has non-integer dilation factor, placing on CPU")
- return False
- if not 1 <= dilation_w_factor <= 2 or not 1 <= dilation_h_factor <= 2:
- print(
- "Warning:",
- op.type,
- "has dilation factors ({}, {}), only factors in range [1, 2] are allowed. Placing on CPU".format(
- dilation_w_factor, dilation_h_factor
- ),
- )
- return False
+ @staticmethod
+ def constraint_dilation_type(op):
+ "Dilation factor values for both width and height must be integer types"
+ w = op.attrs.get("dilation_w_factor", 1)
+ h = op.attrs.get("dilation_h_factor", 1)
+ valid = is_integer(w) and is_integer(h)
+ extra = "Op has dilation factor WxH as: {}x{}".format(repr(w), repr(h))
+ return valid, extra
- # check data type
- ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
- if weight_tensor.element_size() > 1:
- print("Warning: only 8-bit weights are supported, placing on CPU")
- return False
+ @classmethod
+ @docstring_format_args(dilation_range)
+ def constraint_dilation_range(cls, op):
+ "Dilation factor values for both width and height must be in the range [{}, {}]"
+ w = op.attrs.get("dilation_w_factor", 1)
+ h = op.attrs.get("dilation_h_factor", 1)
+ dilation_min, dilation_max = cls.dilation_range
+ valid = (dilation_min <= w <= dilation_max) and (dilation_min <= h <= dilation_max)
+ extra = "Op has dilation factor WxH as: {}x{}".format(w, h)
+ return valid, extra
- if not cls.check_bias_restrictions(bias_tensor):
- return False
+ @classmethod
+ @docstring_format_args(dilated_height_range)
+ def constraint_dilated_height_range(cls, op):
+ "Dilated kernel height must be in the range [{}, {}]"
+ h = (op.weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
+ dilated_height_min, dilated_height_max = cls.dilated_height_range
+ valid = dilated_height_min <= h <= dilated_height_max
+ extra = "Op has dilated kernel height as: {}".format(h)
+ return valid, extra
- # check kernel size [HWIO]
- dilated_weight_w = (weight_tensor.shape[1] - 1) * dilation_w_factor + 1
- dilated_weight_h = (weight_tensor.shape[0] - 1) * dilation_h_factor + 1
+ @classmethod
+ @docstring_format_args(dilated_product_range)
+ def constraint_dilated_product_range(cls, op):
+ "Product of dilated kernel width and height must be in the range [{}, {}]"
+ weights = op.weights
+ w = (weights.shape[1] - 1) * op.attrs.get("dilation_w_factor", 1) + 1
+ h = (weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
+ product = w * h
+ dilated_product_min, dilated_product_max = cls.dilated_product_range
+ valid = dilated_product_min <= product <= dilated_product_max
+ extra = "Op has product of dilated kernel width and height as: {}".format(product)
+ return valid, extra
- # kernel limits
- if not 1 <= dilated_weight_h <= 64:
- print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU")
- return False
- if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64:
- print(
- "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64),
- "placing on CPU",
- )
- return False
+ @staticmethod
+ def constraint_weights_type(op):
+ "Weight Tensor must be 8-bit"
+ weights = op.weights
+ valid = weights.element_size() == 1
+ extra = "Tensor '{}' is {}-bit".format(weights.name, int(weights.element_size() * 8))
+ return valid, extra
- # check non const weights
- if weight_tensor.values is None:
- print("Warning:", op.type, "has non-constant weights, placing on CPU")
- return False
+ @staticmethod
+ def constraint_weights_nonconst(op):
+ "Weight tensor cannot be non-constant"
+ weights = op.weights
+ valid = weights.values is not None
+ extra = "Tensor '{}' has non-constant values".format(weights.name)
+ return valid, extra
- # check weight sums over [HWI]
- zero_point = weight_tensor.quantization.zero_point
- quant_weights = weight_tensor.quant_values.astype(np.int64)
- weights = quant_weights - zero_point
- totals = np.sum(np.absolute(weights), axis=(0, 1, 2))
+ @classmethod
+ @docstring_format_args([weights_limit])
+ def constraint_weights_limit(cls, op):
+ "The sum of the weights cannot exceed {}"
+ weights = op.weights
+ values = weights.quant_values.astype(np.int64) - weights.quantization.zero_point
+ limit = np.amax(np.sum(np.absolute(values), axis=(0, 1, 2)))
+ valid = limit <= cls.weights_limit
+ extra = "Tensor '{}' has the sum of weights: {}".format(weights.name, limit)
+ return valid, extra
- if np.amax(totals) > 127 * 65536:
- print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536))
- return False
+ @classmethod
+ @docstring_format_args([supported_bias_dtypes])
+ def constraint_bias_type(cls, op):
+ "Optional Bias Tensor must be of type: {}"
+ valid = True
+ extra = ""
+ bias = op.bias
+ if bias:
+ valid = bias.dtype in cls.supported_bias_dtypes
+ extra = "Tensor '{}' has data type: {}".format(bias.name, bias.dtype)
+ return valid, extra
- # check batch size
- if ifm_tensor.shape[0] != 1:
- print("Warning: only batch sizes of 1 are supported, placing on CPU")
- return False
+ @staticmethod
+ def constraint_bias_40bit(op):
+ "Optional Bias Tensor values must fit within 40-bits"
+ valid = True
+ extra = ""
+ bias = op.bias
+ if bias and bias.dtype == DataType.int64:
+ valid = all(len(bin(quant_value)[2:]) <= 40 for quant_value in bias.quant_values)
+ extra = "Tensor '{}' has values larger than 40-bits".format(bias.name)
+ return valid, extra
- return True
+ @staticmethod
+ def constraint_batch_size(op):
+ "IFM Tensor batch size must be 1"
+ ifm = op.ifm
+ valid = ifm.shape[0] == 1
+ extra = "Tensor '{}' has batch size: {}".format(ifm.name, ifm.shape[0])
+ return valid, extra
@classmethod
def check_depthwise_convolution_restrictions(cls, op):
@@ -368,7 +440,7 @@ class SupportedOperators:
"Placing on CPU",
)
return False
- return cls.check_convolution_restrictions(op)
+ return True
@classmethod
def check_transpose_convolution_restrictions(cls, op):
@@ -403,8 +475,7 @@ class SupportedOperators:
"minus difference between kernel size and stride. Placing on CPU",
)
return False
-
- return cls.check_convolution_restrictions(op)
+ return True
@classmethod
def check_pooling_restrictions(cls, op):