author     Louis Verhaard <louis.verhaard@arm.com>  2021-02-03 10:22:38 +0100
committer  Louis Verhaard <louis.verhaard@arm.com>  2021-02-03 13:22:56 +0100
commit     c77615121c28409081d2ac6526694edebb8d7255 (patch)
tree       15f234c73d7d0cb45b7d69c2576fc0cb04a46475
parent     66d7ec060a9c97f2095cb65f7242fa044a505810 (diff)
download   ethos-u-vela-c77615121c28409081d2ac6526694edebb8d7255.tar.gz
MLBEDSW-3572: Fused activations must not be int32
Added a supported-operator check that 32-bit fused activation
functions are not supported.

Change-Id: I01fdafeff8fdb13c71eae4f63be7e6f81b9223df
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--  ethosu/vela/supported_operators.py            14
-rw-r--r--  ethosu/vela/test/test_supported_operators.py  10
2 files changed, 24 insertions, 0 deletions
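For context: vela evaluates each registered generic constraint against an op and rejects the op on the first failure, which is how the new check below takes effect. A minimal sketch of that dispatch pattern, assuming each constraint returns a (valid, reason) tuple as constraint_faf_type does; this is an illustrative stand-in, not the actual SupportedOperators.is_operator_supported, which also runs per-operator-type specific constraints and reports failures in more detail:

    # Sketch only: run each generic constraint against the op and stop
    # at the first failure, printing the constraint's docstring (the
    # human-readable rule) and the specific reason it failed.
    def is_operator_supported(generic_constraints, op):
        for constraint in generic_constraints:
            valid, reason = constraint(op)
            if not valid:
                print(f" - {constraint.__doc__}")
                print(f"   {reason}")
                return False
        return True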
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 505d4d1..8bb9c58 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -105,6 +105,7 @@ class SupportedOperators:
     supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops
     # Supported data types
     supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+    supported_faf_dtypes = set((DataType.uint8, DataType.int8, DataType.int16))
     supported_bias_dtypes = set((DataType.int32, DataType.int64))
     supported_pad_dtypes = set((DataType.int32, DataType.int64))
     # Defined ranges for allowed values:
@@ -135,6 +136,7 @@ class SupportedOperators:
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_per_axis)
         self.generic_constraints.append(SupportedOperators.constraint_faf)
+        self.generic_constraints.append(SupportedOperators.constraint_faf_type)
         self.generic_constraints.append(SupportedOperators.constraint_quant_scale_inf)
 
         # Setup specific constraints. Note: the order matters
@@ -451,6 +453,18 @@ class SupportedOperators:
         res = valid, f"Op has its fused activation function as: {faf}"
         return res
 
+    @classmethod
+    @docstring_format_args([_list_formatter(supported_faf_dtypes)])
+    def constraint_faf_type(cls, op):
+        "If a fused activation function is present, the Output tensor must be one of type: {}"
+        if op.activation is None:
+            res = True, "Op has no fused activation function"
+        else:
+            valid = op.ofm.dtype in cls.supported_faf_dtypes
+            ext_type = optype_to_builtintype(op.activation.op_type)
+            res = valid, f"Op has fused activation function {ext_type}, and Output tensor data type: {op.ofm.dtype}"
+        return res
+
     @staticmethod
     def constraint_stride_type(op):
         "Stride values for both width and height must be integer types"
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 5f64dd9..3e9724d 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -147,6 +147,16 @@ def test_constraint_faf():
     assert not support.is_operator_supported(op)
 
 
+def test_constraint_faf_ofm_dtype():
+    # If fused activation function is present, OFM must be 8 or 16 bit
+    shp = [1, 8, 8, 8]
+    for dtype in [DataType.int8, DataType.uint8, DataType.int16, DataType.int32]:
+        op = testutil.create_elemwise_op(Op.Add, "op", shp, shp, shp, datatype=dtype)
+        op.activation = ActivationFunction(Op.Relu)
+        expected = dtype.size_in_bytes() <= 2
+        assert support.is_operator_supported(op) == expected, f"Data type: {dtype}"
+
+
 def test_constraint_conv_pass():
     # First test a simple conv passes
     op = testutil.create_op_with_quant_tensors(Op.Conv2D, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
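Taken together, the constraint and test pin the boundary at 16 bits: with a fused activation attached, int8, uint8, and int16 output tensors remain supported, while int32 is rejected. A small illustration of the expected verdicts, where the hard-coded byte sizes mirror what dtype.size_in_bytes() returns in the test above:

    # Hypothetical walk-through of the test's expectation: dtypes of at
    # most 2 bytes pass constraint_faf_type, the 4-byte int32 fails.
    for name, size_in_bytes in [("uint8", 1), ("int8", 1), ("int16", 2), ("int32", 4)]:
        expected = size_in_bytes <= 2
        print(f"{name}: {'supported' if expected else 'NOT supported'} with fused activation")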