diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index daa208f1..a2f744d3 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -137,8 +137,8 @@ class TFLiteSubgraph: def parse_operator(self, op_index, op_data): op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()] - inputs = [self.tensors[idx] for idx in op_data.InputsAsNumpy()] - outputs = [self.tensors[idx] for idx in op_data.OutputsAsNumpy()] + inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()] + outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()] name = "unknown_op_name" if len(outputs): name = outputs[0].name @@ -153,12 +153,19 @@ class TFLiteSubgraph: if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"): inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0)) - if not op.type.endswith("BackpropInput"): - inputs[2] = clone_and_reshape_tensor(inputs[2], (0,)) + if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type): + # No Bias tensor + inputs.append(None) + if inputs[-1]: + inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,)) if op_type.startswith("FullyConnected"): inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0)) - inputs[2] = clone_and_reshape_tensor(inputs[2], (0,)) + if len(inputs) < 3: + # No Bias tensor + inputs.append(None) + if inputs[-1]: + inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,)) if opt_serializer is not None: op.attrs = opt_serializer.deserialize(op_data) |