aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-09-04 17:18:33 +0200
committerJohan Alfven <johan.alfven@arm.com>2023-09-12 13:07:47 +0200
commitc0bb868fe375ff38eede8be363165794ca780978 (patch)
treec6c1c05695b2f19a3d4e5584e7987e86004683f3 /ethosu/vela/tflite_supported_operators.py
parent26c8e8416589f8f76f16f16483bb2d6aad036dfa (diff)
downloadethos-u-vela-c0bb868fe375ff38eede8be363165794ca780978.tar.gz
MLBEDSW-7997: [MLCE] Extended stride support for TRANSPOSE CONV
- Support for stride WxH 1x1 - Support for stride WxH 2x1 when IFM and KERNEL is 1D shape with height 1 - Added test to supported operators - Updated SUPPORTED_OPS.md Change-Id: Ic1abead8399a5e14a78d962f8aded0d3b3dbfcc4 Signed-off-by: Johan Alfven <johan.alfven@arm.com>X
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 723c5f22..52b04857 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -590,11 +590,24 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_tconv_stride(op):
- "Stride values for both width and height must be 2"
- w = op.kernel.stride.x
- h = op.kernel.stride.y
- valid = (w == 2) and (h == 2)
- return valid, f"Op has stride WxH as: {w}x{h}"
+ """Stride values for width and height must match one of the following criteria:
+ Stride values WxH must be 1x1 or 2x2
+ Stride WxH 2x1 supported if ifm height and kernel height = 1"""
+ s_w = op.kernel.stride.x
+ s_h = op.kernel.stride.y
+ k_h = op.kernel.height
+ i_h = op.ifm.shape[1]
+ valid = False
+ if s_w == 1 and s_h == 1:
+ valid = True
+
+ if s_w == 2 and s_h == 2:
+ valid = True
+
+ if s_w == 2 and s_h == 1 and i_h == 1 and k_h == 1:
+ valid = True
+
+ return valid, f"Op has ifm_height={i_h}, kernel_height={k_h} and stride WxH as {s_w}x{s_h}"
@staticmethod
def constraint_tconv_same(op):