diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 73 |
1 files changed, 45 insertions, 28 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index daea1bf8..b47177f7 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -16,6 +16,8 @@ # Description: # Functions used to read from a TensorFlow Lite format file. import os.path +import struct +import sys import numpy as np @@ -235,34 +237,49 @@ class TFLiteGraph: with open(filename, "rb") as f: buf = bytearray(f.read()) - model = Model.GetRootAsModel(buf, 0) - - self.buffers = [] - for idx in range(model.BuffersLength()): - self.buffers.append(self.parse_buffer(model.Buffers(idx))) - - self.operator_codes = [] - for idx in range(model.OperatorCodesLength()): - self.operator_codes.append(self.parse_operator_code(model.OperatorCodes(idx))) - - self.subgraphs = [] - for idx in range(model.SubgraphsLength()): - self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx))) - - self.nng = Graph(self.name, self.batch_size) - for tflite_sg in self.subgraphs: - sg = Subgraph(tflite_sg.name) - sg.original_inputs = tflite_sg.inputs # Preserve the original input order - sg.output_tensors = tflite_sg.outputs - self.nng.subgraphs.append(sg) - - # Preserve the original metadata - for idx in range(model.MetadataLength()): - meta = model.Metadata(idx) - name = meta.Name() - if name is not None: - buf_data = self.buffers[meta.Buffer()] - self.nng.metadata.append((name, buf_data)) + try: + parsing_step = "parsing root" + model = Model.GetRootAsModel(buf, 0) + + parsing_step = "parsing buffers length" + self.buffers = [] + for idx in range(model.BuffersLength()): + parsing_step = f"parsing buffer {idx}" + self.buffers.append(self.parse_buffer(model.Buffers(idx))) + + parsing_step = "parsing operator codes length" + self.operator_codes = [] + for idx in range(model.OperatorCodesLength()): + parsing_step = f"parsing operator code {idx}" + self.operator_codes.append(self.parse_operator_code(model.OperatorCodes(idx))) + + parsing_step = "parsing subgraphs length" + self.subgraphs = [] + for idx in range(model.SubgraphsLength()): + parsing_step = f"parsing subgraph {idx}" + self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx))) + + self.nng = Graph(self.name, self.batch_size) + for tflite_sg in self.subgraphs: + sg = Subgraph(tflite_sg.name) + sg.original_inputs = tflite_sg.inputs # Preserve the original input order + sg.output_tensors = tflite_sg.outputs + self.nng.subgraphs.append(sg) + + parsing_step = "parsing metadata length" + # Preserve the original metadata + for idx in range(model.MetadataLength()): + parsing_step = f"parsing metadata {idx}" + meta = model.Metadata(idx) + parsing_step = f"parsing metadata name of metadata {idx}" + name = meta.Name() + if name is not None: + parsing_step = f"parsing metadata {idx} ({name})" + buf_data = self.buffers[meta.Buffer()] + self.nng.metadata.append((name, buf_data)) + except (struct.error, TypeError, RuntimeError) as e: + print(f'Error: Invalid tflite file. Got "{e}" while {parsing_step}.') + sys.exit(1) def parse_buffer(self, buf_data): if buf_data.DataLength() == 0: |