diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 60bc6fd0..193a23ff 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -511,8 +511,8 @@ class TFLiteSupportedOperators: """The width and height of the IFM and OFM must match one of the following criteria: IFM W and H must both be 1 IFM must match OFM - OFM W and H must be 2x IFM -1, if align_corners is True - OFM W and H must be 2x IFM, if align_corners is False""" + OFM W and H must be equal and 2/4/8x IFM -1, if align_corners is True + OFM W and H must be equal and 2/4/8x IFM, if align_corners is False""" # Easier to start with False condition as very few cases result in a supported resize valid = False ifm_shape = op.ifm.shape @@ -523,16 +523,12 @@ class TFLiteSupportedOperators: if ((ifm_shape[1] == 1) and (ifm_shape[2] == 1)) or (ifm_shape == ofm_shape): valid = True else: - upscaled_shape = np.array(ifm_shape[1:3]) - out_shape = np.array(ofm_shape[1:3]) - while (upscaled_shape < out_shape).all(): - upscaled_shape *= 2 - if align_corners: - upscaled_shape -= 1 - # Valid if OFM is 2x IFM (-1 for align corners) - if np.array_equal(out_shape, upscaled_shape): - valid = True - break + # Valid if OFM is 2/4/8x IFM (-1 for align corners) + w_upscale_factor = (ofm_shape[1] + 1) / ifm_shape[1] if align_corners else ofm_shape[1] / ifm_shape[1] + h_upscale_factor = (ofm_shape[2] + 1) / ifm_shape[2] if align_corners else ofm_shape[2] / ifm_shape[2] + + valid = w_upscale_factor == h_upscale_factor and w_upscale_factor in [2, 4, 8] + return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and align_corners={align_corners}" @staticmethod |