aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py93
1 files changed, 92 insertions, 1 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 8b759be..a82f812 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -75,6 +75,8 @@ class SupportedOperators:
| resizing_ops
# FC layers
| fc_vector_products
+ # Mean (converts to depthwise conv)
+ | set((Op.Mean,))
)
unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
@@ -99,7 +101,7 @@ class SupportedOperators:
split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
memory_only_ops = set((Op.Reshape, Op.QuantizedReshape,)) | concat_ops | split_ops
- shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
+ shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean))
per_axis_quant_ops = convolution_like_ops # per-axis/channel quantization only currently supported for conv ops
supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops
@@ -118,6 +120,8 @@ class SupportedOperators:
filter_range = (1, 8)
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
+ mean_kernel_product = 64 * 64
+ mean_kernel_product_int8 = 16 * 16
# Supported consumers
supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops
@@ -268,6 +272,13 @@ class SupportedOperators:
# HardSwish specific checks:
self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
+ # Mean specific checks:
+ self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_input_8bit)
+ self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_properties)
+ self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_input_dims)
+ self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_axis)
+ self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product)
+ self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_int8)
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
@@ -1077,3 +1088,83 @@ class SupportedOperators:
if op.attrs.get("keep_num_dims"):
valid = len(op.ifm.shape) == len(op.ofm.shape)
return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"
+
+ def constraint_mean_input_dims(op):
+ "Input tensor must be at least 2D"
+ dims = len(op.inputs[0].shape)
+ return 2 <= dims <= 4, f"Input is {dims}D"
+
+ @staticmethod
+ def constraint_mean_axis(op):
+ "Axis indices must correspond to height and width axes"
+ dims = len(op.inputs[0].shape)
+ axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ if dims == 2 or dims == 3:
+ valid = axis in (0, 1, [0, 1], [1, 0])
+ elif dims == 4:
+ valid = axis in (1, 2, [1, 2], [2, 1])
+ return valid, f"Axis is {axis}"
+
+ @classmethod
+ @docstring_format_args([mean_kernel_product])
+ def constraint_mean_height_width_product(cls, op):
+ "Product of height and width can be at most {}"
+ shape = op.inputs[0].shape
+ hi = 0 if len(shape) < 4 else 1
+ h, w = shape[hi : hi + 2]
+ max_prod = cls.mean_kernel_product
+ return h * w <= max_prod, f"Product of height and width is {h * w}"
+
+ @classmethod
+ @docstring_format_args([mean_kernel_product_int8])
+ def constraint_mean_height_width_product_int8(cls, op):
+ """Product of IFM height and width can be at most {} when the following are true:
+ IFM dimensions are 4,
+ Axis indices are 1 and 2,
+ keep_dims is set to True and
+ IFM datatype is int8"""
+ shape = op.ifm.shape
+ axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ if (
+ len(shape) != 4
+ or op.ifm.dtype != DataType.int8
+ or not op.attrs.get("keep_dims")
+ or axis not in ([1, 2], [2, 1])
+ ):
+ return True, ""
+ hi = 0 if len(shape) < 4 else 1
+ h, w = shape[hi : hi + 2]
+ max_prod = cls.mean_kernel_product_int8
+ return h * w <= max_prod, f"Product of height and width is {h * w}"
+
+ @staticmethod
+ def constraint_mean_properties(op):
+ """Every constraint in either one (or both) of the following sets of constraints must be fulfilled:
+ Set A:
+ IFM dimensions are 4,
+ Axis indices are 1 and 2,
+ keep_dims is set to True
+ Set B:
+ IFM zero point and OFM zero point are the same,
+ IFM scale and OFM scale are the same"""
+ seta, setb = True, True
+ extra = []
+ axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ if len(op.ifm.shape) != 4:
+ seta = False
+ extra.append(f"IFM shape is {op.ifm.shape}")
+ if not any(np.array_equal(axis, ax) for ax in ([1, 2], [2, 1])):
+ seta = False
+ extra.append(f"Axis is {axis}")
+ if not op.attrs.get("keep_dims"):
+ seta = False
+ extra.append("keep_dims is False")
+ ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
+ if ifmq.zero_point != ofmq.zero_point:
+ setb = False
+ extra.append("IFM zero point does not match OFM zero point")
+ if ifmq.scale_f32 != ofmq.scale_f32:
+ setb = False
+ extra.append("IFM scale does not match OFM scale")
+ extra = ", ".join(extra)
+ return seta or setb, f"The following constraints were not fulfilled: {extra}"