aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
authorCharles Xu <charles.xu@arm.com>2020-08-06 12:17:26 +0200
committerCharles Xu <charles.xu@arm.com>2020-08-24 14:57:01 +0200
commit87c13507b4f44edff0c819aa2bb6f9966c85c841 (patch)
treeb7cb602b09ae92b74c77ad90f6f5f3675948e9b9 /ethosu/vela/supported_operators.py
parentb9fc33c194036973273604d5fd7af9e814133238 (diff)
downloadethos-u-vela-87c13507b4f44edff0c819aa2bb6f9966c85c841.tar.gz
MLBEDSW-2654: Convert Resizebilinear to a number of 2x2 pools
Signed-off-by: Charles Xu <charles.xu@arm.com> Change-Id: Ida307afc33cd7963bdeb505df400732a3efcc846
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 65588bf4..ab7f2db7 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
+import numpy as np
+
from .data_type import BaseType
from .data_type import DataType
@@ -287,13 +289,15 @@ class SupportedOperators:
return True
if op.inputs[0].shape == op.outputs[0].shape:
return True
- upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2]
- out_shape = op.outputs[0].shape[1:3]
- if not op.attrs["align_corners"] and out_shape != upscaled_shape:
- return False
- elif op.attrs["align_corners"] and out_shape != [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
- return False
- return True
+ upscaled_shape = np.array(op.inputs[0].shape[1:3])
+ out_shape = np.array(op.outputs[0].shape[1:3])
+ while (upscaled_shape < out_shape).all():
+ upscaled_shape *= 2
+ if op.attrs["align_corners"]:
+ upscaled_shape -= 1
+ if np.array_equal(out_shape, upscaled_shape):
+ return True
+ return False
def check_vector_product_restrictions(self, op):
# check data type