aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 65588bf4..ab7f2db7 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
+import numpy as np
+
from .data_type import BaseType
from .data_type import DataType
@@ -287,13 +289,15 @@ class SupportedOperators:
return True
if op.inputs[0].shape == op.outputs[0].shape:
return True
- upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2]
- out_shape = op.outputs[0].shape[1:3]
- if not op.attrs["align_corners"] and out_shape != upscaled_shape:
- return False
- elif op.attrs["align_corners"] and out_shape != [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
- return False
- return True
+ upscaled_shape = np.array(op.inputs[0].shape[1:3])
+ out_shape = np.array(op.outputs[0].shape[1:3])
+ while (upscaled_shape < out_shape).all():
+ upscaled_shape *= 2
+ if op.attrs["align_corners"]:
+ upscaled_shape -= 1
+ if np.array_equal(out_shape, upscaled_shape):
+ return True
+ return False
def check_vector_product_restrictions(self, op):
# check data type