author     Dwight Lidman <dwight.lidman@arm.com>    2020-09-28 15:53:40 +0200
committer  patrik.gustavsson <patrik.gustavsson@arm.com>    2020-10-12 07:02:26 +0000
commit     8359a474e4f125382fd7b7d5431a612f6013f107 (patch)
tree       50c5fc46f0b6117af77722a61edce19131b35e11
parent     04f8c009d17e339d5afd515a57f98c31e4297fe8 (diff)
download   ethos-u-vela-8359a474e4f125382fd7b7d5431a612f6013f107.tar.gz
MLBEDSW-3061: Update supported_operators.py
This commit changes and amends some parts of the restriction functions in order to make sure operators are correctly placed.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I336cf33a874c9078a5bbf81ce129ff917dbc5e9a
-rw-r--r--   ethosu/vela/numeric_util.py                     10
-rw-r--r--   ethosu/vela/supported_operators.py             161
-rw-r--r--   ethosu/vela/test/test_supported_operators.py     9
-rw-r--r--   ethosu/vela/test/testutil.py                    24
4 files changed, 178 insertions, 26 deletions
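The restriction functions touched by this patch decide whether an operator runs on the Ethos-U NPU or falls back to the CPU. Two patterns appear in the diff below: generic constraints (static methods returning a (valid, extra) pair, registered in self.generic_constraints) and per-operator check_* methods that return a plain bool and print a warning when they reject an operator. The following is a hedged sketch of how such a generic constraint could be driven; the actual loop inside is_operator_supported() is not part of this patch and may differ in detail.

# Illustrative sketch only; the real driver lives in SupportedOperators.is_operator_supported().
def run_generic_constraints(support, op):
    for constraint in support.generic_constraints:
        valid, extra = constraint(op)  # e.g. constraint_tens_quant_none_check below
        if not valid:
            # The constraint's docstring doubles as the user-facing rule description.
            print("Warning: {} '{}' placed on CPU: {}".format(op.type, op.name, constraint.__doc__))
            if extra:
                print("   " + extra)
            return False
    return True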
diff --git a/ethosu/vela/numeric_util.py b/ethosu/vela/numeric_util.py
index 4ebef8e5..20aa4a05 100644
--- a/ethosu/vela/numeric_util.py
+++ b/ethosu/vela/numeric_util.py
@@ -93,3 +93,13 @@ def full_shape(dim, shape, fill):
def overlaps(start1, end1, start2, end2):
return start1 < end2 and start2 < end1
+
+
+def is_integer(num):
+ if isinstance(num, (int, np.integer)):
+ return True
+ if type(num) is float and num.is_integer():
+ return True
+ if isinstance(num, np.inexact) and np.mod(num, 1) == 0:
+ return True
+ return False
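A quick illustration of how the new helper behaves (illustrative usage, not part of the patch; the np.* references in the new code imply numpy is already imported in numeric_util.py):

import numpy as np
from ethosu.vela.numeric_util import is_integer

assert is_integer(2)                # plain int
assert is_integer(np.int32(3))      # numpy integer
assert is_integer(2.0)              # float with no fractional part
assert is_integer(np.float32(4.0))  # numpy float with no fractional part
assert not is_integer(1.5)          # fractional stride/dilation, so the op would go to the CPU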
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 3d4a09f3..357e7fe8 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -19,6 +19,7 @@ import numpy as np
from .data_type import BaseType
from .data_type import DataType
+from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
@@ -130,6 +131,7 @@ class SupportedOperators:
self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
self.generic_constraints.append(SupportedOperators.constraint_faf)
self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
+ self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
def is_operator_supported(self, op):
if op.type not in SupportedOperators.supported_operators:
@@ -235,36 +237,76 @@ class SupportedOperators:
extra.append("quantization.scale_f32={}".format(tens.quantization.scale_f32))
return valid, " ".join(extra)
+ @staticmethod
+ def constraint_tens_quant_none_check(op):
+ "Tensors must have quantization parameters"
+ valid = True
+ extra = []
+ tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+ for tens in tensors:
+ if tens.quantization is None:
+ valid = False
+ extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
+ return valid, ", ".join(extra)
+
@classmethod
def check_convolution_restrictions(cls, op):
# check stride
- if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
+ stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
+ if not is_integer(stride_w) or not is_integer(stride_h):
+ print("Warning:", op.type, "has non-integer stride, placing on CPU")
+ return False
+ if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
+ print(
+ "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
+ op.type, stride_w, stride_h
+ )
+ )
return False
# check dilation
dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
- if dilation_w_factor > 2 or dilation_h_factor > 2:
+ if not is_integer(dilation_w_factor) or not is_integer(dilation_h_factor):
+ print("Warning:", op.type, "has non-integer dilation factor, placing on CPU")
+ return False
+ if not 1 <= dilation_w_factor <= 2 or not 1 <= dilation_h_factor <= 2:
+ print(
+ "Warning:",
+ op.type,
+ "has dilation factors ({}, {}), only factors in range [1, 2] are allowed. Placing on CPU".format(
+ dilation_w_factor, dilation_h_factor
+ ),
+ )
return False
# check data type
ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
if weight_tensor.element_size() > 1:
+ print("Warning: only 8-bit weights are supported, placing on CPU")
return False
if not cls.check_bias_restrictions(bias_tensor):
return False
# check kernel size [HWIO]
- dilated_weight_w = weight_tensor.shape[1] + (weight_tensor.shape[1] - 1) * (dilation_w_factor - 1)
- dilated_weight_h = weight_tensor.shape[0] + (weight_tensor.shape[0] - 1) * (dilation_h_factor - 1)
+ dilated_weight_w = (weight_tensor.shape[1] - 1) * dilation_w_factor + 1
+ dilated_weight_h = (weight_tensor.shape[0] - 1) * dilation_h_factor + 1
- if dilated_weight_w > 64 or dilated_weight_h > 64:
+ # kernel limits
+ if not 1 <= dilated_weight_h <= 64:
+ print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU")
+ return False
+ if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64:
+ print(
+ "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64),
+ "placing on CPU",
+ )
return False
# check non const weights
if weight_tensor.values is None:
- print("Warning:", op.type, "has non-const weights, placing on CPU")
+ print("Warning:", op.type, "has non-constant weights, placing on CPU")
return False
# check weight sums over [HWI]
@@ -274,10 +316,12 @@ class SupportedOperators:
totals = np.sum(np.absolute(weights), axis=(0, 1, 2))
if np.amax(totals) > 127 * 65536:
+ print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536))
return False
# check batch size
if ifm_tensor.shape[0] != 1:
+ print("Warning: only batch sizes of 1 are supported, placing on CPU")
return False
return True
@@ -289,6 +333,11 @@ class SupportedOperators:
if op.attrs["depth_multiplier"] > 1 and not (
(ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
):
+ print(
+ "Warning: for depth multipliers > 1,",
+ "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
+ "Placing on CPU",
+ )
return False
return cls.check_convolution_restrictions(op)
@@ -296,7 +345,8 @@ class SupportedOperators:
def check_transpose_convolution_restrictions(cls, op):
# check stride
stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
- if stride_h != stride_w != 2:
+ if stride_h != 2 or stride_w != 2:
+ print("Warning: stride must be equal to 2, placing on CPU")
return False
# check output dimensions
@@ -305,12 +355,24 @@ class SupportedOperators:
ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
if op.attrs["padding"] == b"SAME":
if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
+ print(
+ "Warning: for",
+ op.type,
+ "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
+ "Placing on CPU",
+ )
return False
elif op.attrs["padding"] == b"VALID":
kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
):
+ print(
+ "Warning: for",
+ op.type,
+ "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
+ "minus difference between kernel size and stride. Placing on CPU",
+ )
return False
return cls.check_convolution_restrictions(op)
@@ -318,33 +380,56 @@ class SupportedOperators:
@classmethod
def check_pooling_restrictions(cls, op):
# check stride
- if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3:
+ stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
+ if not is_integer(stride_w) or not is_integer(stride_h):
+ print("Warning:", op.type, "has non-integer stride, placing on CPU")
+ return False
+ if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
+ print(
+ "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
+ op.type, stride_w, stride_h
+ )
+ )
return False
# check data type
ifm_tensor, ofm_tensor = op.get_ifm_ofm()
if ifm_tensor.dtype != ofm_tensor.dtype:
if op.type != Op.ReduceSum:
+ print("Warning: input data type doesn't match output data type, placing on CPU")
return False
# TODO: else check ReduceSum restrictions.
# check batch size
if ifm_tensor.shape[0] != 1:
+ print("Warning: input batch size must be 1, placing on CPU")
return False
- if op.type in cls.avg_pooling_ops:
- # check kernel size
- if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8):
+ # check kernel size
+ kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
+ if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
+ if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
+ print(
+ "Warning:",
+ op.type,
+ "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
+ kernel_w, kernel_h
+ ),
+ )
return False
- if op.attrs["padding"] == b"VALID" and (
- op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256
- ):
+ if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops:
+ if not 1 <= kernel_w * kernel_h <= 256 * 256:
+ print(
+ "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
+ 256 * 256
+ ),
+ "placing on CPU",
+ )
return False
-
- if op.type in cls.max_pooling_ops:
- # check kernel size (any padding)
- if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256:
+ if not 1 <= kernel_h <= 256:
+ print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
return False
+
return True
@classmethod
@@ -368,8 +453,15 @@ class SupportedOperators:
@classmethod
def check_vector_product_restrictions(cls, op):
# check data type
- _, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
+ ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
if weight_tensor.element_size() > 1:
+ print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
+ return False
+
+ # check batch size
+ batch_sizes = {1, 2, 4, 8}
+ if ifm_tensor.shape[0] not in batch_sizes:
+ print("Warning: only batch sizes {} supported for {}, placing on CPU".format(batch_sizes, op.type))
return False
if not cls.check_bias_restrictions(bias_tensor):
@@ -391,43 +483,65 @@ class SupportedOperators:
op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
and ifm_tensor.dtype != ofm_tensor.dtype
):
+ print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
return False
if op.type in cls.binary_elem_wise_add_mul_sub:
# both inputs must have same type
if ifm_tensor.dtype != ifm2_tensor.dtype:
+ print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
return False
# signed input check
if ifm_tensor.dtype.type & BaseType.Signed:
# output must be signed
if ofm_tensor.dtype.type & BaseType.Unsigned:
+ print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
return False
# and 8, 16 or 32-bit
- if ofm_tensor.element_size() not in (1, 2, 4):
+ bit_lengths = {8, 16, 32}
+ if ofm_tensor.element_size() * 8 not in bit_lengths:
+ print(
+ "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
+ )
return False
# unsigned input check, output must be same type or int32
if ifm_tensor.dtype.type & BaseType.Unsigned and not (
ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
):
+ print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
return False
elif op.type in cls.binary_elem_wise_shift_ops:
if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
+ print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
return False
if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
+ print("Warning:", op.type, "output datatype is not int32, placing on CPU")
return False
# check batch size
if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
+ print(
+ "Warning:",
+ op.type,
+ "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
+ )
return False
if op.type in cls.binary_elem_wise_main_ops: # if op type is unary, ifm2_tensor is None
if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
+ print(
+ "Warning:",
+ op.type,
+ "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
+ )
return False
# negative alpha values are not supported
if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
+ print("Warning:", op.type, "has negative alpha, placing on CPU")
return False
# check if ifm or ifm2 has ofm shape
if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
+ print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
return False
if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
@@ -545,13 +659,18 @@ class SupportedOperators:
# check data type
if ifm_tensor.dtype != ofm_tensor.dtype:
+ print("Warning:", op.type, "input type differs from output type, placing on CPU")
return False
if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
+ print(
+ "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
+ )
return False
# check shape
if ifm_tensor.shape != ofm_tensor.shape:
+ print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
return False
return True
@@ -560,12 +679,14 @@ class SupportedOperators:
def check_bias_restrictions(cls, bias_tensor):
# check data type
if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
+ print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
return False
# check if values fits in 40-bit
if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
for quant_value in bias_tensor.quant_values:
if not (-(1 << 39) <= quant_value < (1 << 39)):
+ print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
return False
return True
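As a worked example of the rewritten kernel-size check (illustrative numbers, not from the patch): the dilated extent of a kernel is (size - 1) * dilation + 1, which is algebraically identical to the old size + (size - 1) * (dilation - 1) expression, and it must satisfy the new explicit limits:

kernel_h, kernel_w = 3, 3                     # weight tensor layout is [H, W, I, O]
dilation_h, dilation_w = 2, 2
dilated_h = (kernel_h - 1) * dilation_h + 1   # 5
dilated_w = (kernel_w - 1) * dilation_w + 1   # 5
assert 1 <= dilated_h <= 64                   # kernel height limit
assert 1 <= dilated_w * dilated_h <= 64 * 64  # kernel area limit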
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 20d448d7..1fb452cf 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -30,11 +30,14 @@ support = SupportedOperators()
def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
+ qp = QuantizationParameters()
in0 = Tensor(in_shape, DataType.uint8, "in")
- in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets)
- in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets)
- in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1])
+ in0.quantization = qp
+ in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
+ in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
+ in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
out = Tensor(out_shape, DataType.uint8, "out")
+ out.quantization = qp
attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
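A possible follow-up test for the new generic quantization constraint, using the updated helper (a sketch only; such a test is not part of this patch, and the positive case also depends on the remaining StridedSlice restrictions passing):

def test_constraint_tens_quant_none():
    op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [0, 0, 0, 0], [1, 5, 5, 10])
    assert support.is_operator_supported(op)
    # Removing the quantization parameters from the IFM should push the op to the CPU
    op.inputs[0].quantization = None
    assert not support.is_operator_supported(op)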
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index adb874a0..c5ff0033 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -22,6 +22,7 @@ from ethosu.vela.data_type import DataType
from ethosu.vela.nn_graph import Subgraph
from ethosu.vela.operation import Operation
from ethosu.vela.tensor import create_const_tensor
+from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
@@ -38,7 +39,17 @@ def create_arch():
)
-def create_elemwise_op(type, name, ifm_shape, ifm2_shape, ofm_shape, datatype=DataType.uint8):
+def create_elemwise_op(
+ type,
+ name,
+ ifm_shape,
+ ifm2_shape,
+ ofm_shape,
+ datatype=DataType.uint8,
+ ifm_quant=QuantizationParameters(),
+ ifm2_quant=QuantizationParameters(),
+ ofm_quant=QuantizationParameters(),
+):
# Creates elementwise operation with constant IFM/IFM2
if datatype.size_in_bytes() == 1:
np_type = np.uint8
@@ -47,9 +58,16 @@ def create_elemwise_op(type, name, ifm_shape, ifm2_shape, ofm_shape, datatype=Da
else:
np_type = np.int32
op = Operation(type, name)
- op.add_input_tensor(create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type))
- op.add_input_tensor(create_const_tensor(name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type))
+ op.add_input_tensor(
+ create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type, quantization=ifm_quant)
+ )
+ op.add_input_tensor(
+ create_const_tensor(
+ name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type, quantization=ifm2_quant
+ )
+ )
ofm = Tensor(ofm_shape, datatype, name + "_ofm")
+ ofm.quantization = ofm_quant
op.set_output_tensor(ofm)
return op
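Finally, a usage sketch for the extended helper (illustrative values, not part of the patch): explicit QuantizationParameters can now be supplied per tensor, so tests can exercise the quantization constraints directly.

from ethosu.vela.operation import Op
from ethosu.vela.tensor import QuantizationParameters

qp = QuantizationParameters()
qp.scale_f32 = 1.0   # assumed attribute names, matching the scale_f32 field checked in supported_operators.py
qp.zero_point = 0
op = create_elemwise_op(Op.Add, "add1", [1, 8, 8, 4], [1, 8, 8, 4], [1, 8, 8, 4], ofm_quant=qp)
# Passing ifm_quant=None instead would trip the new constraint_tens_quant_none_check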