aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-08-20 10:53:02 +0200
committertim.hall <tim.hall@arm.com>2020-08-21 15:30:36 +0000
commit67e0d8f24fcb86115e834acd19dc57027b03ea4f (patch)
tree748a85cc9aca976b74e18d1e4bead38344c32922 /ethosu/vela/supported_operators.py
parent1575b9413de2569de25bb2520b898a91f24ad3b0 (diff)
downloadethos-u-vela-67e0d8f24fcb86115e834acd19dc57027b03ea4f.tar.gz
MLBEDSW-2663: Handle optional tensors
Includes a number of changes: * Handle non-existing optional inputs * Handle disabled optional inputs (-1 indexed) * Added unit tests for parsing operators * Add bias tensor to the different Convolutions + FullyConnected if it's missing. Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: Ib88d2b610314b1c886fc0aef4f9da87430ce6ae5
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py22
1 files changed, 16 insertions, 6 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 9e415b51..e6aaca31 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -131,22 +131,32 @@ class SupportedOperators:
def check_generic_restrictions(self, op):
# check fully defined shapes
for t in op.inputs:
+ if not t:
+ continue
if not t.has_fully_defined_shape():
print("Warning:", op.type, "has input(s) of undefined shape, placing on CPU")
return False
if t.shape == [] and op.type not in self.binary_elem_wise_main_ops:
- print("Warning:", op.type, "has input(s) of shape [].",
- "Scalar input or broadcasting is not supported for this operator,",
- "placing on CPU")
+ print(
+ "Warning:",
+ op.type,
+ "has input(s) of shape [].",
+ "Scalar input or broadcasting is not supported for this operator,",
+ "placing on CPU",
+ )
return False
for t in op.outputs:
if not t.has_fully_defined_shape():
print("Warning:", op.type, "has output(s) of undefined shape, placing on CPU")
return False
if t.shape == []:
- print("Warning:", op.type, "has output(s) of shape [].",
- "Scalar input or broadcasting is not supported for this operator,",
- "placing on CPU")
+ print(
+ "Warning:",
+ op.type,
+ "has output(s) of shape [].",
+ "Scalar input or broadcasting is not supported for this operator,",
+ "placing on CPU",
+ )
return False
# check data type