From 189f748e1a79ed88044efbe7137963bca830cbb5 Mon Sep 17 00:00:00 2001 From: Diqing Zhong Date: Tue, 26 Jan 2021 12:12:51 +0100 Subject: MLBEDSW-3224: Support HardSwish Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f Signed-off-by: Diqing Zhong --- ethosu/vela/supported_operators.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'ethosu/vela/supported_operators.py') diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 1bebe9af..99a4ba10 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -87,7 +87,7 @@ class SupportedOperators: set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops ) relu_ops = Op.op_set(Op.is_relu_op) - activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,)) + activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish)) npu_post_ops = ( # activation functions activation_ops @@ -261,6 +261,10 @@ class SupportedOperators: self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant) self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm) + # HardSwish specific checks: + self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit) + self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types) + def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in SupportedOperators.supported_operators: @@ -933,6 +937,13 @@ class SupportedOperators: valid = ofm_dtype == DataType.int32 return valid, f"Op has ofm_dtype={ofm_dtype}" + @staticmethod + def constraint_input_8bit(op): + "IFM must be int8 or uint8" + ifm_dtype = op.ifm.dtype + valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8) + return valid, f"Op has ifm_dtype={ifm_dtype}" + @staticmethod def constraint_matching_quantization_parameters(op): "Both Input quantization parameters must match OFM quantization parameters" -- cgit v1.2.1