aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py73
1 files changed, 45 insertions, 28 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index daea1bf8..b47177f7 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -16,6 +16,8 @@
# Description:
# Functions used to read from a TensorFlow Lite format file.
import os.path
+import struct
+import sys
import numpy as np
@@ -235,34 +237,49 @@ class TFLiteGraph:
with open(filename, "rb") as f:
buf = bytearray(f.read())
- model = Model.GetRootAsModel(buf, 0)
-
- self.buffers = []
- for idx in range(model.BuffersLength()):
- self.buffers.append(self.parse_buffer(model.Buffers(idx)))
-
- self.operator_codes = []
- for idx in range(model.OperatorCodesLength()):
- self.operator_codes.append(self.parse_operator_code(model.OperatorCodes(idx)))
-
- self.subgraphs = []
- for idx in range(model.SubgraphsLength()):
- self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx)))
-
- self.nng = Graph(self.name, self.batch_size)
- for tflite_sg in self.subgraphs:
- sg = Subgraph(tflite_sg.name)
- sg.original_inputs = tflite_sg.inputs # Preserve the original input order
- sg.output_tensors = tflite_sg.outputs
- self.nng.subgraphs.append(sg)
-
- # Preserve the original metadata
- for idx in range(model.MetadataLength()):
- meta = model.Metadata(idx)
- name = meta.Name()
- if name is not None:
- buf_data = self.buffers[meta.Buffer()]
- self.nng.metadata.append((name, buf_data))
+ try:
+ parsing_step = "parsing root"
+ model = Model.GetRootAsModel(buf, 0)
+
+ parsing_step = "parsing buffers length"
+ self.buffers = []
+ for idx in range(model.BuffersLength()):
+ parsing_step = f"parsing buffer {idx}"
+ self.buffers.append(self.parse_buffer(model.Buffers(idx)))
+
+ parsing_step = "parsing operator codes length"
+ self.operator_codes = []
+ for idx in range(model.OperatorCodesLength()):
+ parsing_step = f"parsing operator code {idx}"
+ self.operator_codes.append(self.parse_operator_code(model.OperatorCodes(idx)))
+
+ parsing_step = "parsing subgraphs length"
+ self.subgraphs = []
+ for idx in range(model.SubgraphsLength()):
+ parsing_step = f"parsing subgraph {idx}"
+ self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx)))
+
+ self.nng = Graph(self.name, self.batch_size)
+ for tflite_sg in self.subgraphs:
+ sg = Subgraph(tflite_sg.name)
+ sg.original_inputs = tflite_sg.inputs # Preserve the original input order
+ sg.output_tensors = tflite_sg.outputs
+ self.nng.subgraphs.append(sg)
+
+ parsing_step = "parsing metadata length"
+ # Preserve the original metadata
+ for idx in range(model.MetadataLength()):
+ parsing_step = f"parsing metadata {idx}"
+ meta = model.Metadata(idx)
+ parsing_step = f"parsing metadata name of metadata {idx}"
+ name = meta.Name()
+ if name is not None:
+ parsing_step = f"parsing metadata {idx} ({name})"
+ buf_data = self.buffers[meta.Buffer()]
+ self.nng.metadata.append((name, buf_data))
+ except (struct.error, TypeError, RuntimeError) as e:
+ print(f'Error: Invalid tflite file. Got "{e}" while {parsing_step}.')
+ sys.exit(1)
def parse_buffer(self, buf_data):
if buf_data.DataLength() == 0: