aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 1bebe9af..99a4ba10 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -87,7 +87,7 @@ class SupportedOperators:
set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
relu_ops = Op.op_set(Op.is_relu_op)
- activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,))
+ activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
npu_post_ops = (
# activation functions
activation_ops
@@ -261,6 +261,10 @@ class SupportedOperators:
self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm)
+ # HardSwish specific checks:
+ self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
+ self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in SupportedOperators.supported_operators:
@@ -934,6 +938,13 @@ class SupportedOperators:
return valid, f"Op has ofm_dtype={ofm_dtype}"
@staticmethod
+ def constraint_input_8bit(op):
+ "IFM must be int8 or uint8"
+ ifm_dtype = op.ifm.dtype
+ valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
+ return valid, f"Op has ifm_dtype={ifm_dtype}"
+
+ @staticmethod
def constraint_matching_quantization_parameters(op):
"Both Input quantization parameters must match OFM quantization parameters"
valid = True