aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 0a1af829..eec1b900 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -152,6 +152,9 @@ class SupportedOperators:
"placing on CPU",
)
return False
+ if len(t.shape) > 4:
+ print("Warning:", op.type, "has input(s) of unsupported shape", t.shape, "placing on CPU")
+ return False
for t in op.outputs:
if not t.has_fully_defined_shape():
print("Warning:", op.type, "has output(s) of undefined shape, placing on CPU")
@@ -165,6 +168,9 @@ class SupportedOperators:
"placing on CPU",
)
return False
+ if len(t.shape) > 4:
+ print("Warning:", op.type, "has output(s) of unsupported shape", t.shape, "placing on CPU")
+ return False
# check data type
tensors = [t for t in op.get_ifm_ifm2_weights_ofm() if t is not None]
@@ -447,6 +453,25 @@ class SupportedOperators:
if num_to_be_inferred > 1:
print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
return False
+ if op.type.find("Concat") != -1:
+ axis = op.attrs.get("axis", None)
+ if axis is None:
+ print("Warning:", op.type, "invalid or missing axis, placing on CPU")
+ return False
+ if axis < 0:
+ axis += len(op.inputs[0].shape)
+ if not 0 < axis < len(op.inputs[0].shape):
+ print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
+ return False
+ ofm = op.outputs[0]
+ ofm_dims = len(ofm.shape)
+ for ifm in op.inputs:
+ if len(ifm.shape) != ofm_dims:
+ return False
+ for i in range(ofm_dims):
+ if i != axis and ifm.shape[i] != ofm.shape[i]:
+ print("Warning:", op.type, "invalid ifm:", ifm.name, ifm.shape, "mismatch in dimension", i, ", placing on CPU")
+ return False
return True