aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py20
1 files changed, 8 insertions, 12 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 60bc6fd0..193a23ff 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -511,8 +511,8 @@ class TFLiteSupportedOperators:
"""The width and height of the IFM and OFM must match one of the following criteria:
IFM W and H must both be 1
IFM must match OFM
- OFM W and H must be 2x IFM -1, if align_corners is True
- OFM W and H must be 2x IFM, if align_corners is False"""
+ OFM W and H must be equal and 2/4/8x IFM -1, if align_corners is True
+ OFM W and H must be equal and 2/4/8x IFM, if align_corners is False"""
# Easier to start with False condition as very few cases result in a supported resize
valid = False
ifm_shape = op.ifm.shape
@@ -523,16 +523,12 @@ class TFLiteSupportedOperators:
if ((ifm_shape[1] == 1) and (ifm_shape[2] == 1)) or (ifm_shape == ofm_shape):
valid = True
else:
- upscaled_shape = np.array(ifm_shape[1:3])
- out_shape = np.array(ofm_shape[1:3])
- while (upscaled_shape < out_shape).all():
- upscaled_shape *= 2
- if align_corners:
- upscaled_shape -= 1
- # Valid if OFM is 2x IFM (-1 for align corners)
- if np.array_equal(out_shape, upscaled_shape):
- valid = True
- break
+ # Valid if OFM is 2/4/8x IFM (-1 for align corners)
+ w_upscale_factor = (ofm_shape[1] + 1) / ifm_shape[1] if align_corners else ofm_shape[1] / ifm_shape[1]
+ h_upscale_factor = (ofm_shape[2] + 1) / ifm_shape[2] if align_corners else ofm_shape[2] / ifm_shape[2]
+
+ valid = w_upscale_factor == h_upscale_factor and w_upscale_factor in [2, 4, 8]
+
return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and align_corners={align_corners}"
@staticmethod