aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-08-20 10:53:02 +0200
committertim.hall <tim.hall@arm.com>2020-08-21 15:30:36 +0000
commit67e0d8f24fcb86115e834acd19dc57027b03ea4f (patch)
tree748a85cc9aca976b74e18d1e4bead38344c32922 /ethosu/vela/tflite_reader.py
parent1575b9413de2569de25bb2520b898a91f24ad3b0 (diff)
downloadethos-u-vela-67e0d8f24fcb86115e834acd19dc57027b03ea4f.tar.gz
MLBEDSW-2663: Handle optional tensors
Includes a number of changes: * Handle non-existing optional inputs * Handle disabled optional inputs (-1 indexed) * Added unit tests for parsing operators * Add bias tensor to the different Convolutions + FullyConnected if it's missing. Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: Ib88d2b610314b1c886fc0aef4f9da87430ce6ae5
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index daa208f1..a2f744d3 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -137,8 +137,8 @@ class TFLiteSubgraph:
def parse_operator(self, op_index, op_data):
op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()]
- inputs = [self.tensors[idx] for idx in op_data.InputsAsNumpy()]
- outputs = [self.tensors[idx] for idx in op_data.OutputsAsNumpy()]
+ inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()]
+ outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()]
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
@@ -153,12 +153,19 @@ class TFLiteSubgraph:
if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
- if not op.type.endswith("BackpropInput"):
- inputs[2] = clone_and_reshape_tensor(inputs[2], (0,))
+ if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type):
+ # No Bias tensor
+ inputs.append(None)
+ if inputs[-1]:
+ inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,))
if op_type.startswith("FullyConnected"):
inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
- inputs[2] = clone_and_reshape_tensor(inputs[2], (0,))
+ if len(inputs) < 3:
+ # No Bias tensor
+ inputs.append(None)
+ if inputs[-1]:
+ inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,))
if opt_serializer is not None:
op.attrs = opt_serializer.deserialize(op_data)