diff options
author | Jacob Bohlin <jacob.bohlin@arm.com> | 2020-08-20 10:53:02 +0200 |
---|---|---|
committer | tim.hall <tim.hall@arm.com> | 2020-08-21 15:30:36 +0000 |
commit | 67e0d8f24fcb86115e834acd19dc57027b03ea4f (patch) | |
tree | 748a85cc9aca976b74e18d1e4bead38344c32922 /ethosu/vela/tflite_reader.py | |
parent | 1575b9413de2569de25bb2520b898a91f24ad3b0 (diff) | |
download | ethos-u-vela-67e0d8f24fcb86115e834acd19dc57027b03ea4f.tar.gz |
MLBEDSW-2663: Handle optional tensors
Includes a number of changes:
* Handle non-existing optional inputs
* Handle disabled optional inputs (-1 indexed)
* Added unit tests for parsing operators
* Add bias tensor to the different Convolutions + FullyConnected if
it's missing.
Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: Ib88d2b610314b1c886fc0aef4f9da87430ce6ae5
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index daa208f1..a2f744d3 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -137,8 +137,8 @@ class TFLiteSubgraph: def parse_operator(self, op_index, op_data): op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()] - inputs = [self.tensors[idx] for idx in op_data.InputsAsNumpy()] - outputs = [self.tensors[idx] for idx in op_data.OutputsAsNumpy()] + inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()] + outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()] name = "unknown_op_name" if len(outputs): name = outputs[0].name @@ -153,12 +153,19 @@ class TFLiteSubgraph: if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"): inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0)) - if not op.type.endswith("BackpropInput"): - inputs[2] = clone_and_reshape_tensor(inputs[2], (0,)) + if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type): + # No Bias tensor + inputs.append(None) + if inputs[-1]: + inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,)) if op_type.startswith("FullyConnected"): inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0)) - inputs[2] = clone_and_reshape_tensor(inputs[2], (0,)) + if len(inputs) < 3: + # No Bias tensor + inputs.append(None) + if inputs[-1]: + inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,)) if opt_serializer is not None: op.attrs = opt_serializer.deserialize(op_data) |