commit     597fd3f88397501d61855a327c9632fc1dab3f57 (patch)
author     Fredrik Svedberg <fredrik.svedberg@arm.com>  2020-08-13 10:02:53 +0200
committer  Fredrik Svedberg <fredrik.svedberg@arm.com>  2020-08-19 07:27:56 +0200
tree       2c42401ed06b82a0252663568154bf8472fd8024 /ethosu/vela/supported_operators.py
parent     30cb47abd68b125d5c3fc315948187a9ea8dbd43 (diff)
download   ethos-u-vela-597fd3f88397501d61855a327c9632fc1dab3f57.tar.gz
[MLBEDSW-2657] Softmax uint8/int8

Added graph rewrite of Softmax for uint8/int8.

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Iecdd5d2cd3156a601b3313debba4a3562e6be5d7
Diffstat (limited to 'ethosu/vela/supported_operators.py')

 ethosu/vela/supported_operators.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)
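In short: this commit splits the shift operators SHL and SHR out of binary_elem_wise_add_mul_sub into a new binary_elem_wise_shift_ops set, threads that set through the int32 size check and the element-wise dtype checks, and relaxes the Softmax dtype restriction from int16-only to uint8/int8/int16. A minimal sketch of how the operator sets compose after the change (the add/mul/sub member list is abbreviated to the names visible in the hunks below):

# Sketch of the operator-set composition after this commit. Set names match
# the diff; binary_elem_wise_add_mul_sub is abbreviated to its visible members.
unary_elem_wise_main_ops = {"LeakyRelu", "Abs", "CLZ"}
binary_elem_wise_min_max_ops = {"Minimum", "Maximum"}
binary_elem_wise_shift_ops = {"SHL", "SHR"}  # new: split out of add/mul/sub
binary_elem_wise_add_mul_sub = {"AddAct", "Mul", "Add", "Sub"}  # abbreviated
binary_elem_wise_main_ops = (
    binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops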
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 43ba36f0..fdf0c6b3 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -52,6 +52,7 @@ class SupportedOperators:
)
self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
+ self.binary_elem_wise_shift_ops = set(("SHL", "SHR",))
self.binary_elem_wise_add_mul_sub = set(
(
"AddAct",
@@ -63,11 +64,9 @@ class SupportedOperators:
"Mul",
"Add",
"Sub",
- "SHL",
- "SHR",
)
)
- self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
+ self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
self.activation_ops = set(
(
@@ -153,7 +152,7 @@ class SupportedOperators:
return False
if (
t.element_size() > 2
- and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub
+ and op.type not in set(("Requantize", "ReduceSum", "CLZ",)) | self.binary_elem_wise_add_mul_sub | self.binary_elem_wise_shift_ops
):
return False
# check size
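The hunk above widens the 32-bit whitelist: a tensor wider than 16 bits (element_size() > 2 bytes) is only accepted when the op is Requantize, ReduceSum, CLZ, one of the add/mul/sub ops, or, after this change, a shift op. A standalone sketch of that predicate, with the add/mul/sub set abbreviated as before (helper name is illustrative, not vela's real API):

# Hedged sketch of the widened int32 whitelist.
INT32_OK_OPS = {"Requantize", "ReduceSum", "CLZ", "Add", "Mul", "Sub", "SHL", "SHR"}

def tensor_width_ok(op_type, element_size_bytes):
    # 8- and 16-bit tensors always pass; 32-bit only for whitelisted ops.
    return element_size_bytes <= 2 or op_type in INT32_OK_OPS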
@@ -311,6 +310,11 @@ class SupportedOperators:
ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
):
return False
+ elif op.type in self.binary_elem_wise_shift_ops | set(("CLZ",)):
+ if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
+ return False
+ if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32:
+ return False
# check batch size
if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
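The new branch pins down the shift/CLZ dtypes: both inputs must be int32, and CLZ and SHL must also write int32, while SHR's output dtype is left unconstrained here, implying SHR may produce a narrower result. A hedged sketch mirroring the hunk's structure (the function and dtype spellings are illustrative, not the real vela code):

def shift_clz_dtypes_ok(op_type, ifm_dtype, ifm2_dtype, ofm_dtype):
    if op_type not in ("SHL", "SHR", "CLZ"):
        return True  # rule only applies to the shift/CLZ ops
    if ifm_dtype != "int32" or ifm2_dtype != "int32":
        return False  # both inputs must be 32-bit
    if op_type in ("CLZ", "SHL") and ofm_dtype != "int32":
        return False  # CLZ/SHL must also produce int32; SHR is unconstrained
    return True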
@@ -365,8 +369,8 @@ class SupportedOperators:
if ifm_tensor.dtype != ofm_tensor.dtype:
return False
- if ifm_tensor.dtype != DataType.int16:
- return False # TODO: Implement support for 8-bit Softmax
+ if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
+ return False
# check batch size
if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1:
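Finally, the last hunk relaxes the Softmax restriction from int16-only to the three quantized types handled by the new graph rewrite. A minimal sketch of the resulting dtype rule, with a small usage example (names are illustrative):

# Minimal sketch of the relaxed Softmax dtype rule.
SOFTMAX_DTYPES = ("uint8", "int8", "int16")  # int16 was the only option before

def softmax_dtypes_ok(ifm_dtype, ofm_dtype):
    # Input and output dtypes must match and be a supported quantized type.
    return ifm_dtype == ofm_dtype and ifm_dtype in SOFTMAX_DTYPES

assert softmax_dtypes_ok("int8", "int8")        # newly supported
assert not softmax_dtypes_ok("int32", "int32")  # still rejected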