diff options
author | Johan Alfven <johan.alfven@arm.com> | 2023-09-04 17:18:33 +0200 |
---|---|---|
committer | Johan Alfven <johan.alfven@arm.com> | 2023-09-12 13:07:47 +0200 |
commit | c0bb868fe375ff38eede8be363165794ca780978 (patch) | |
tree | c6c1c05695b2f19a3d4e5584e7987e86004683f3 /ethosu/vela/tflite_supported_operators.py | |
parent | 26c8e8416589f8f76f16f16483bb2d6aad036dfa (diff) | |
download | ethos-u-vela-c0bb868fe375ff38eede8be363165794ca780978.tar.gz |
MLBEDSW-7997: [MLCE] Extended stride support for TRANSPOSE CONV
- Support for stride WxH 1x1
- Support for stride WxH 2x1 when IFM and KERNEL
is 1D shape with height 1
- Added test to supported operators
- Updated SUPPORTED_OPS.md
Change-Id: Ic1abead8399a5e14a78d962f8aded0d3b3dbfcc4
Signed-off-by: Johan Alfven <johan.alfven@arm.com>X
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 723c5f22..52b04857 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -590,11 +590,24 @@ class TFLiteSupportedOperators: @staticmethod def constraint_tconv_stride(op): - "Stride values for both width and height must be 2" - w = op.kernel.stride.x - h = op.kernel.stride.y - valid = (w == 2) and (h == 2) - return valid, f"Op has stride WxH as: {w}x{h}" + """Stride values for width and height must match one of the following criteria: + Stride values WxH must be 1x1 or 2x2 + Stride WxH 2x1 supported if ifm height and kernel height = 1""" + s_w = op.kernel.stride.x + s_h = op.kernel.stride.y + k_h = op.kernel.height + i_h = op.ifm.shape[1] + valid = False + if s_w == 1 and s_h == 1: + valid = True + + if s_w == 2 and s_h == 2: + valid = True + + if s_w == 2 and s_h == 1 and i_h == 1 and k_h == 1: + valid = True + + return valid, f"Op has ifm_height={i_h}, kernel_height={k_h} and stride WxH as {s_w}x{s_h}" @staticmethod def constraint_tconv_same(op): |