aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--  ethosu/vela/supported_operators.py  | 16
1 file changed, 10 insertions, 6 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 43ba36f0..fdf0c6b3 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -52,6 +52,7 @@ class SupportedOperators:
)
self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
+ self.binary_elem_wise_shift_ops = set(("SHL", "SHR",))
self.binary_elem_wise_add_mul_sub = set(
(
"AddAct",
@@ -63,11 +64,9 @@ class SupportedOperators:
"Mul",
"Add",
"Sub",
- "SHL",
- "SHR",
)
)
- self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
+ self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
self.activation_ops = set(
(
@@ -153,7 +152,7 @@ class SupportedOperators:
return False
if (
t.element_size() > 2
- and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub
+ and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
):
return False
# check size
@@ -311,6 +310,11 @@ class SupportedOperators:
ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
):
return False
+        elif op.type in self.binary_elem_wise_shift_ops | set(("CLZ",)):
+ if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
+ return False
+ if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32:
+ return False
# check batch size
if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
@@ -365,8 +369,8 @@ class SupportedOperators:
if ifm_tensor.dtype != ofm_tensor.dtype:
return False
- if ifm_tensor.dtype != DataType.int16:
- return False # TODO: Implement support for 8-bit Softmax
+ if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
+ return False
# check batch size
if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1: