diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2020-09-29 10:00:39 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-09-30 07:52:39 +0000 |
commit | 0f98b361288c71fca327969346db32de098c797b (patch) | |
tree | 8b2905a6e763832a0029179d655c481b14e0a8a1 /ethosu/vela/supported_operators.py | |
parent | 0265f402c7ae1e875470298b4130fcc2f7ab4e23 (diff) | |
download | ethos-u-vela-0f98b361288c71fca327969346db32de098c797b.tar.gz |
[MLBEDSW-2802] Fix 5D tensor crash
Fixed crash in networks with 5D tensors.
Fixed crash for (int32) tensors without quantization.
Added validity checks for concatenation.
Moved unfusing of activation function from tflite_reader to graph_optimiser.
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Ib9ba8891dc95ef5491e15d0feedef44331a26393
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 25 |
1 file changed, 25 insertions, 0 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 0a1af829..eec1b900 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -152,6 +152,9 @@ class SupportedOperators: "placing on CPU", ) return False + if len(t.shape) > 4: + print("Warning:", op.type, "has input(s) of unsupported shape", t.shape, "placing on CPU") + return False for t in op.outputs: if not t.has_fully_defined_shape(): print("Warning:", op.type, "has output(s) of undefined shape, placing on CPU") @@ -165,6 +168,9 @@ class SupportedOperators: "placing on CPU", ) return False + if len(t.shape) > 4: + print("Warning:", op.type, "has output(s) of unsupported shape", t.shape, "placing on CPU") + return False # check data type tensors = [t for t in op.get_ifm_ifm2_weights_ofm() if t is not None] @@ -447,6 +453,25 @@ class SupportedOperators: if num_to_be_inferred > 1: print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU") return False + if op.type.find("Concat") != -1: + axis = op.attrs.get("axis", None) + if axis is None: + print("Warning:", op.type, "invalid or missing axis, placing on CPU") + return False + if axis < 0: + axis += len(op.inputs[0].shape) + if not 0 < axis < len(op.inputs[0].shape): + print("Warning:", op.type, "invalid axis", axis, ", placing on CPU") + return False + ofm = op.outputs[0] + ofm_dims = len(ofm.shape) + for ifm in op.inputs: + if len(ifm.shape) != ofm_dims: + return False + for i in range(ofm_dims): + if i != axis and ifm.shape[i] != ofm.shape[i]: + print("Warning:", op.type, "invalid ifm:", ifm.name, ifm.shape, "mismatch in dimension", i, ", placing on CPU") + return False return True |