aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index e0ee6163..86cc3c07 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -420,8 +420,8 @@ class SupportedOperators:
if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
return False
- # check batch size
- if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1:
+ # check shape
+ if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape:
return False
return True