From 8359a474e4f125382fd7b7d5431a612f6013f107 Mon Sep 17 00:00:00 2001
From: Dwight Lidman
Date: Mon, 28 Sep 2020 15:53:40 +0200
Subject: MLBEDSW-3061: Update supported_operators.py

This commit updates and extends parts of the restriction functions
to make sure operators are placed correctly.

Signed-off-by: Dwight Lidman
Change-Id: I336cf33a874c9078a5bbf81ce129ff917dbc5e9a
---
 ethosu/vela/numeric_util.py                  |  10 ++
 ethosu/vela/supported_operators.py           | 161 +++++++++++++++++++++++----
 ethosu/vela/test/test_supported_operators.py |   9 +-
 ethosu/vela/test/testutil.py                 |  24 +++-
 4 files changed, 178 insertions(+), 26 deletions(-)

diff --git a/ethosu/vela/numeric_util.py b/ethosu/vela/numeric_util.py
index 4ebef8e5..20aa4a05 100644
--- a/ethosu/vela/numeric_util.py
+++ b/ethosu/vela/numeric_util.py
@@ -93,3 +93,13 @@ def full_shape(dim, shape, fill):
 
 def overlaps(start1, end1, start2, end2):
     return start1 < end2 and start2 < end1
+
+
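+# Accepts Python and NumPy integers as well as floats that hold an exact
+# integer value: is_integer(2), is_integer(2.0) and is_integer(np.float32(2))
+# are all True, while is_integer(2.5) is False.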
Placing on CPU".format( + op.type, stride_w, stride_h + ) + ) return False # check dilation dilation_w_factor = op.attrs.get("dilation_w_factor", 1) dilation_h_factor = op.attrs.get("dilation_h_factor", 1) - if dilation_w_factor > 2 or dilation_h_factor > 2: + if not is_integer(dilation_w_factor) or not is_integer(dilation_h_factor): + print("Warning:", op.type, "has non-integer dilation factor, placing on CPU") + return False + if not 1 <= dilation_w_factor <= 2 or not 1 <= dilation_h_factor <= 2: + print( + "Warning:", + op.type, + "has dilation factors ({}, {}), only factors in range [1, 2] are allowed. Placing on CPU".format( + dilation_w_factor, dilation_h_factor + ), + ) return False # check data type ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm() if weight_tensor.element_size() > 1: + print("Warning: only 8-bit weights are supported, placing on CPU") return False if not cls.check_bias_restrictions(bias_tensor): return False # check kernel size [HWIO] - dilated_weight_w = weight_tensor.shape[1] + (weight_tensor.shape[1] - 1) * (dilation_w_factor - 1) - dilated_weight_h = weight_tensor.shape[0] + (weight_tensor.shape[0] - 1) * (dilation_h_factor - 1) + dilated_weight_w = (weight_tensor.shape[1] - 1) * dilation_w_factor + 1 + dilated_weight_h = (weight_tensor.shape[0] - 1) * dilation_h_factor + 1 - if dilated_weight_w > 64 or dilated_weight_h > 64: + # kernel limits + if not 1 <= dilated_weight_h <= 64: + print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU") + return False + if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64: + print( + "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64), + "placing on CPU", + ) return False # check non const weights if weight_tensor.values is None: - print("Warning:", op.type, "has non-const weights, placing on CPU") + print("Warning:", op.type, "has non-constant weights, placing on CPU") return False # check weight sums over [HWI] @@ -274,10 +316,12 @@ class SupportedOperators: totals = np.sum(np.absolute(weights), axis=(0, 1, 2)) if np.amax(totals) > 127 * 65536: + print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536)) return False # check batch size if ifm_tensor.shape[0] != 1: + print("Warning: only batch sizes of 1 are supported, placing on CPU") return False return True @@ -289,6 +333,11 @@ class SupportedOperators: if op.attrs["depth_multiplier"] > 1 and not ( (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"]) ): + print( + "Warning: for depth multipliers > 1,", + "number of input channels must be 1 and number of output channels must be equal to depth multiplier.", + "Placing on CPU", + ) return False return cls.check_convolution_restrictions(op) @@ -296,7 +345,8 @@ class SupportedOperators: def check_transpose_convolution_restrictions(cls, op): # check stride stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"] - if stride_h != stride_w != 2: + if stride_h != 2 or stride_w != 2: + print("Warning: stride must be equal to 2, placing on CPU") return False # check output dimensions @@ -305,12 +355,24 @@ class SupportedOperators: ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2] if op.attrs["padding"] == b"SAME": if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w): + print( + "Warning: for", + op.type, + "using SAME padding, output dimensions must equal input dimensions multiplied by stride.", + "Placing on CPU", + ) 
 
-        if dilated_weight_w > 64 or dilated_weight_h > 64:
+        # kernel limits
+        if not 1 <= dilated_weight_h <= 64:
+            print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU")
+            return False
+        if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64:
+            print(
+                "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64),
+                "placing on CPU",
+            )
             return False
 
         # check non const weights
         if weight_tensor.values is None:
-            print("Warning:", op.type, "has non-const weights, placing on CPU")
+            print("Warning:", op.type, "has non-constant weights, placing on CPU")
             return False
 
         # check weight sums over [HWI]
@@ -274,10 +316,12 @@ class SupportedOperators:
         totals = np.sum(np.absolute(weights), axis=(0, 1, 2))
 
         if np.amax(totals) > 127 * 65536:
+            print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536))
             return False
 
         # check batch size
         if ifm_tensor.shape[0] != 1:
+            print("Warning: only batch sizes of 1 are supported, placing on CPU")
             return False
 
         return True
@@ -289,6 +333,11 @@ class SupportedOperators:
         if op.attrs["depth_multiplier"] > 1 and not (
             (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
         ):
+            print(
+                "Warning: for depth multipliers > 1,",
+                "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
+                "Placing on CPU",
+            )
             return False
         return cls.check_convolution_restrictions(op)
 
@@ -296,7 +345,8 @@ class SupportedOperators:
     def check_transpose_convolution_restrictions(cls, op):
         # check stride
         stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
-        if stride_h != stride_w != 2:
+        if stride_h != 2 or stride_w != 2:
+            print("Warning: stride must be equal to 2, placing on CPU")
             return False
 
         # check output dimensions
@@ -305,12 +355,24 @@ class SupportedOperators:
         ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
         if op.attrs["padding"] == b"SAME":
             if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
+                print(
+                    "Warning: for",
+                    op.type,
+                    "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
+                    "Placing on CPU",
+                )
                 return False
         elif op.attrs["padding"] == b"VALID":
             kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
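+            # e.g. with ifm_h = 16, stride_h = 2 and kernel_h = 3, ofm_h must be 16 * 2 + max(3 - 2, 0) = 33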
Placing on CPU".format( + kernel_w, kernel_h + ), + ) return False - if op.attrs["padding"] == b"VALID" and ( - op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256 - ): + if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops: + if not 1 <= kernel_w * kernel_h <= 256 * 256: + print( + "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format( + 256 * 256 + ), + "placing on CPU", + ) return False - - if op.type in cls.max_pooling_ops: - # check kernel size (any padding) - if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256: + if not 1 <= kernel_h <= 256: + print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU") return False + return True @classmethod @@ -368,8 +453,15 @@ class SupportedOperators: @classmethod def check_vector_product_restrictions(cls, op): # check data type - _, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm() + ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm() if weight_tensor.element_size() > 1: + print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type)) + return False + + # check batch size + batch_sizes = {1, 2, 4, 8} + if ifm_tensor.shape[0] not in batch_sizes: + print("Warning: only batch sizes {} supported for {}, placing on CPU".format(batch_sizes, op.type)) return False if not cls.check_bias_restrictions(bias_tensor): @@ -391,43 +483,65 @@ class SupportedOperators: op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops and ifm_tensor.dtype != ofm_tensor.dtype ): + print("Warning:", op.type, "must have same input and output datatype, placing on CPU") return False if op.type in cls.binary_elem_wise_add_mul_sub: # both inputs must have same type if ifm_tensor.dtype != ifm2_tensor.dtype: + print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU") return False # signed input check if ifm_tensor.dtype.type & BaseType.Signed: # output must be signed if ofm_tensor.dtype.type & BaseType.Unsigned: + print("Warning: only signed output types supported for {}, placing on CPU".format(op.type)) return False # and 8, 16 or 32-bit - if ofm_tensor.element_size() not in (1, 2, 4): + bit_lengths = {8, 16, 32} + if ofm_tensor.element_size() * 8 not in bit_lengths: + print( + "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths) + ) return False # unsigned input check, output must be same type or int32 if ifm_tensor.dtype.type & BaseType.Unsigned and not ( ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32 ): + print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU") return False elif op.type in cls.binary_elem_wise_shift_ops: if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32: + print("Warning:", op.type, "input datatypes are not int32, placing on CPU") return False if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32: + print("Warning:", op.type, "output datatype is not int32, placing on CPU") return False # check batch size if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1: + print( + "Warning:", + op.type, + "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU", + ) return False if op.type in cls.binary_elem_wise_main_ops: # if op type is unary, 
+                if ofm_tensor.element_size() * 8 not in bit_lengths:
+                    print(
+                        "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
+                    )
                     return False
             # unsigned input check, output must be same type or int32
             if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                 ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
             ):
+                print("Warning:", op.type, "has unsigned input but output is not same type or int32, placing on CPU")
                 return False
         elif op.type in cls.binary_elem_wise_shift_ops:
             if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
+                print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
                 return False
             if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
+                print("Warning:", op.type, "output datatype is not int32, placing on CPU")
                 return False
 
         # check batch size
         if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
+            print(
+                "Warning:",
+                op.type,
+                "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
+            )
             return False
         if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
             if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
+                print(
+                    "Warning:",
+                    op.type,
+                    "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
+                )
                 return False
 
         # negative alpha values are not supported
         if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
+            print("Warning:", op.type, "has negative alpha, placing on CPU")
             return False
 
         # check if ifm or ifm2 has ofm shape
         if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
+            print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
             return False
 
         if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
@@ -545,13 +659,18 @@ class SupportedOperators:
 
         # check data type
         if ifm_tensor.dtype != ofm_tensor.dtype:
+            print("Warning:", op.type, "input type differs from output type, placing on CPU")
             return False
 
         if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
+            print(
+                "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
+            )
             return False
 
         # check shape
         if ifm_tensor.shape != ofm_tensor.shape:
+            print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
             return False
 
         return True
 
     @classmethod
@@ -560,12 +679,14 @@ class SupportedOperators:
     def check_bias_restrictions(cls, bias_tensor):
         # check data type
         if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
+            print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
             return False
 
         # check if values fit in 40-bit
         if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
             for quant_value in bias_tensor.quant_values:
                 if not (-(1 << 39) <= quant_value < (1 << 39)):
+                    print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
                     return False
 
         return True

diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 20d448d7..1fb452cf 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -30,11 +30,14 @@ support = SupportedOperators()
 
 
 def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+    qp = QuantizationParameters()
     in0 = Tensor(in_shape, DataType.uint8, "in")
-    in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets)
-    in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets)
-    in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1])
+    in0.quantization = qp
+    in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
+    in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
+    in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
     out = Tensor(out_shape, DataType.uint8, "out")
+    out.quantization = qp
     attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
     return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
 
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index adb874a0..c5ff0033 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -22,6 +22,7 @@ from ethosu.vela.data_type import DataType
 from ethosu.vela.nn_graph import Subgraph
 from ethosu.vela.operation import Operation
 from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import QuantizationParameters
 from ethosu.vela.tensor import Tensor
 
 
@@ -38,7 +39,17 @@ def create_arch():
     )
 
 
-def create_elemwise_op(type, name, ifm_shape, ifm2_shape, ofm_shape, datatype=DataType.uint8):
+def create_elemwise_op(
+    type,
+    name,
+    ifm_shape,
+    ifm2_shape,
+    ofm_shape,
+    datatype=DataType.uint8,
+    ifm_quant=QuantizationParameters(),
+    ifm2_quant=QuantizationParameters(),
+    ofm_quant=QuantizationParameters(),
+):
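+    # note: the default QuantizationParameters() instances above are created once,
+    # when the function is defined, and are shared by all calls that use the defaults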
     # Creates elementwise operation with constant IFM/IFM2
     if datatype.size_in_bytes() == 1:
         np_type = np.uint8
@@ -47,9 +58,16 @@ def create_elemwise_op(type, name, ifm_shape, ifm2_shape, ofm_shape, datatype=Da
     else:
         np_type = np.int32
     op = Operation(type, name)
-    op.add_input_tensor(create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type))
-    op.add_input_tensor(create_const_tensor(name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type))
+    op.add_input_tensor(
+        create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type, quantization=ifm_quant)
+    )
+    op.add_input_tensor(
+        create_const_tensor(
+            name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type, quantization=ifm2_quant
+        )
+    )
     ofm = Tensor(ofm_shape, datatype, name + "_ofm")
+    ofm.quantization = ofm_quant
     op.set_output_tensor(ofm)
     return op
-- 
cgit v1.2.1