diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 5667aff5..9d312e52 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -20,6 +20,7 @@ import os.path import numpy as np from .errors import InputFileError +from .errors import TensorError from .nn_graph import Graph from .nn_graph import Subgraph from .operation import Operation @@ -69,14 +70,16 @@ class TFLiteSubgraph: self.tensors.append(self.parse_tensor(subgraph.Tensors(idx))) for idx in range(subgraph.OperatorsLength()): - self.parse_operator(subgraph.Operators(idx)) + self.parse_operator(idx, subgraph.Operators(idx)) - self.outputs = [self.tensors[idx] for idx in subgraph.OutputsAsNumpy()] - self.inputs = [self.tensors[idx] for idx in subgraph.InputsAsNumpy()] + self.outputs = self.get_tensors_from_indices_remove_duplicates(subgraph.OutputsAsNumpy(), "output") + self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input") # Fix up tensors without operations. Generate either Placeholder or Constant ops for tens in self.inputs: - assert not tens.ops + if tens.ops != []: + TensorError(tens, "This subgraph input tensor has unexpected driving operators.") + op = Operation("Placeholder", tens.name) op.outputs = [tens] tens.ops = [op] @@ -87,6 +90,21 @@ class TFLiteSubgraph: op.outputs = [tens] tens.ops = [op] + def get_tensors_from_indices_remove_duplicates(self, indices, warning_str): + tensors = [] + for idx in indices: + tensor = self.tensors[idx] + if tensor not in tensors: + tensors.append(tensor) + else: + print( + "Warning: Subgraph {0} tensor ({1}) with idx = {2} already seen. Removing the duplicate.".format( + warning_str, tensor, idx + ) + ) + + return tensors + def parse_tensor(self, tens_data): np_shape = tens_data.ShapeAsNumpy() shape = list(np_shape) if type(np_shape) is np.ndarray else [] @@ -121,7 +139,7 @@ class TFLiteSubgraph: tens.values = tens.quantization.dequantize(tens.quant_values) return tens - def parse_operator(self, op_data): + def parse_operator(self, op_index, op_data): op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()] inputs = [self.tensors[idx] for idx in op_data.InputsAsNumpy()] outputs = [self.tensors[idx] for idx in op_data.OutputsAsNumpy()] @@ -129,6 +147,7 @@ class TFLiteSubgraph: if len(outputs): name = outputs[0].name op = Operation(op_type, name) + op.op_index = op_index op.inputs = inputs op.outputs = outputs for out in op.outputs: @@ -143,7 +162,7 @@ class TFLiteSubgraph: inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0)) if opt_serializer is not None: - op.attrs = opt_serializer.deserialize(op_data.BuiltinOptions(), op_data.CustomOptionsAsNumpy()) + op.attrs = opt_serializer.deserialize(op_data) if "stride_w" in op.attrs: op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1) |