aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDwight Lidman <dwight.lidman@arm.com>2020-08-17 11:56:10 +0200
committertim.hall <tim.hall@arm.com>2020-08-19 11:28:35 +0000
commit2573311dbed6728fddba8bcb0cd03c35c45bbc03 (patch)
tree25095cd37a90bb81a7f5af1583a62fc2642f6619
parentfa34c6f4eebec25814c1e620a85721416f4d4ce3 (diff)
downloadethos-u-vela-2573311dbed6728fddba8bcb0cd03c35c45bbc03.tar.gz
MLBEDSW-2729: Add restrictions for shapeless tensors
Vela often fails when encountering operators that have inputs or outputs with shape == []. Only for elementwise ops where shape is broadcasted from IFM2 to IFM1 is this supported. This commit adds a restriction which places ops with shape [] tensors on the CPU except in the special case of broadcasting for elemwise ops. Signed-off-by: Dwight Lidman <dwight.lidman@arm.com> Change-Id: I5b0855233e3b83870209f4da00fb2dbd0184fee0
-rw-r--r--ethosu/vela/supported_operators.py18
1 files changed, 16 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index fdf0c6b3..c4186018 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -138,9 +138,23 @@ class SupportedOperators:
def check_generic_restrictions(self, op):
# check fully defined shapes
- for t in op.inputs + op.outputs:
+ for t in op.inputs:
if not t.has_fully_defined_shape():
- print("Warning:", op, "has inputs/outputs of undefined shape, placing on CPU")
+ print("Warning:", op.type, "has input(s) of undefined shape, placing on CPU")
+ return False
+ if t.shape == [] and op.type not in self.binary_elem_wise_main_ops:
+ print("Warning:", op.type, "has input(s) of shape [].",
+ "Scalar input or broadcasting is not supported for this operator,",
+ "placing on CPU")
+ return False
+ for t in op.outputs:
+ if not t.has_fully_defined_shape():
+ print("Warning:", op.type, "has output(s) of undefined shape, placing on CPU")
+ return False
+ if t.shape == []:
+ print("Warning:", op.type, "has output(s) of shape [].",
+ "Scalar input or broadcasting is not supported for this operator,",
+ "placing on CPU")
return False
# check data type