diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 1bebe9af..99a4ba10 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -87,7 +87,7 @@ class SupportedOperators: set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops ) relu_ops = Op.op_set(Op.is_relu_op) - activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,)) + activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish)) npu_post_ops = ( # activation functions activation_ops @@ -261,6 +261,10 @@ class SupportedOperators: self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant) self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm) + # HardSwish specific checks: + self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit) + self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types) + def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in SupportedOperators.supported_operators: @@ -934,6 +938,13 @@ class SupportedOperators: return valid, f"Op has ofm_dtype={ofm_dtype}" @staticmethod + def constraint_input_8bit(op): + "IFM must be int8 or uint8" + ifm_dtype = op.ifm.dtype + valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8) + return valid, f"Op has ifm_dtype={ifm_dtype}" + + @staticmethod def constraint_matching_quantization_parameters(op): "Both Input quantization parameters must match OFM quantization parameters" valid = True |