diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2020-09-04 09:46:17 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-09-07 06:23:54 +0000 |
commit | 835d8e10f33f411664cebe65d3f6a872f6cc849a (patch) | |
tree | 60c0cfd477cac7b604ade275ec82f54cbb14e9e8 /ethosu/vela/supported_operators.py | |
parent | e5cf95b8c3de4e1e4cbc7046cafd4d84c7492596 (diff) | |
download | ethos-u-vela-835d8e10f33f411664cebe65d3f6a872f6cc849a.tar.gz |
[MLBEDSW-2928] Add batching to softmax
Added batching to softmax by reshaping the input.
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I0b516f9bf2410fb86372b229beba4a7280c498cc
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index e0ee6163..86cc3c07 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -420,8 +420,8 @@ class SupportedOperators: if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16): return False - # check batch size - if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1: + # check shape + if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape: return False return True |