aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
authorRickard Bolin <rickard.bolin@arm.com>2024-01-31 08:42:00 +0000
committerRickard Bolin <rickard.bolin@arm.com>2024-02-06 14:10:31 +0100
commit646314ef1ee268cb972f3c918a49bff85748a332 (patch)
tree119132bd62c75f90dd5a728f2e5e545454fcbcfb /ethosu
parentfdbb072dacae339dd3f8efd3fb70fa84b9296033 (diff)
downloadethos-u-vela-646314ef1ee268cb972f3c918a49bff85748a332.tar.gz
MLBEDSW-8620: Fix MirrorPad supported ops check
Change-Id: I1458009f4b92c1a599efa3a63d6768148e55606d Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/tflite_supported_operators.py5
-rw-r--r--ethosu/vela/vela.py4
2 files changed, 5 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index ad61fcab..91a3ee83 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -831,10 +831,11 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_mirror_pad_padding_values(op):
"The number of pad values for each direction must not be larger than the ifm size in that dimension"
+ valid = True
pad_tensor = op.inputs[1].values
ifm_shape = op.inputs[0].shape
- for dim_padding, ifm_dim_shape in enumerate(pad_tensor, ifm_shape):
- if any(dim_padding > ifm_dim_shape):
+ for dim_padding, ifm_dim_shape in zip(pad_tensor, ifm_shape):
+ if any([pad_val > ifm_dim_shape for pad_val in dim_padding]):
valid = False
return valid, f"IFM shape: {ifm_shape}, number of padding values per dimension: {pad_tensor}"
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 66a21046..a4b93f10 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -178,7 +178,7 @@ def print_subgraph_io_summary(nng):
def generate_license():
lines = [
"<!--",
- "SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>",
+ "SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>",
"",
"SPDX-License-Identifier: Apache-2.0",
"",