diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 43ba36f0..fdf0c6b3 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -52,6 +52,7 @@ class SupportedOperators: ) self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",)) self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",)) + self.binary_elem_wise_shift_ops = set(("SHL", "SHR",)) self.binary_elem_wise_add_mul_sub = set( ( "AddAct", @@ -63,11 +64,9 @@ class SupportedOperators: "Mul", "Add", "Sub", - "SHL", - "SHR", ) ) - self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub + self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops self.activation_ops = set( ( @@ -153,7 +152,7 @@ class SupportedOperators: return False if ( t.element_size() > 2 - and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub + and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops ): return False # check size @@ -311,6 +310,11 @@ class SupportedOperators: ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32 ): return False + elif op.type in self.binary_elem_wise_shift_ops | set(("CLZ")): + if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32: + return False + if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32: + return False # check batch size if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1: @@ -365,8 +369,8 @@ class SupportedOperators: if ifm_tensor.dtype != ofm_tensor.dtype: return False - if ifm_tensor.dtype != DataType.int16: - return False # TODO: Implement support for 8-bit Softmax + if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16): + return False # check batch size if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1: |