aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 723c5f2..52b0485 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -590,11 +590,24 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_tconv_stride(op):
- "Stride values for both width and height must be 2"
- w = op.kernel.stride.x
- h = op.kernel.stride.y
- valid = (w == 2) and (h == 2)
- return valid, f"Op has stride WxH as: {w}x{h}"
+ """Stride values for width and height must match one of the following criteria:
+ Stride values WxH must be 1x1 or 2x2
+ Stride WxH 2x1 supported if ifm height and kernel height = 1"""
+ s_w = op.kernel.stride.x
+ s_h = op.kernel.stride.y
+ k_h = op.kernel.height
+ i_h = op.ifm.shape[1]
+ valid = False
+ if s_w == 1 and s_h == 1:
+ valid = True
+
+ if s_w == 2 and s_h == 2:
+ valid = True
+
+ if s_w == 2 and s_h == 1 and i_h == 1 and k_h == 1:
+ valid = True
+
+ return valid, f"Op has ifm_height={i_h}, kernel_height={k_h} and stride WxH as {s_w}x{s_h}"
@staticmethod
def constraint_tconv_same(op):