aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorHenrik G Olsson <henrik.olsson@arm.com>2021-03-23 17:34:49 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-04-06 11:59:35 +0000
commitea9b23c51345eb9e7108993f8fb2344f3c978b16 (patch)
treeade7c6f714fe61a90a3e5c576eefefc005155806 /ethosu/vela/tflite_reader.py
parentbb010eae126537010924b10d7ff8be4890dde184 (diff)
downloadethos-u-vela-ea9b23c51345eb9e7108993f8fb2344f3c978b16.tar.gz
MLBEDSW-4249 Hide stack traces in error messages
When faced with an invalid tflite file we now catch the exception to make it clear to the user that the issue is with the input and not with Vela, instead of just crashing. Same also applies to our own Vela error messages. Signed-off-by: Henrik G Olsson <henrik.olsson@arm.com> Change-Id: I56a81c5be9e1f46f3b98a88c6d24ee42fa0e450d
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py73
1 files changed, 45 insertions, 28 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index daea1bf8..b47177f7 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -16,6 +16,8 @@
# Description:
# Functions used to read from a TensorFlow Lite format file.
import os.path
+import struct
+import sys
import numpy as np
@@ -235,34 +237,49 @@ class TFLiteGraph:
with open(filename, "rb") as f:
buf = bytearray(f.read())
- model = Model.GetRootAsModel(buf, 0)
-
- self.buffers = []
- for idx in range(model.BuffersLength()):
- self.buffers.append(self.parse_buffer(model.Buffers(idx)))
-
- self.operator_codes = []
- for idx in range(model.OperatorCodesLength()):
- self.operator_codes.append(self.parse_operator_code(model.OperatorCodes(idx)))
-
- self.subgraphs = []
- for idx in range(model.SubgraphsLength()):
- self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx)))
-
- self.nng = Graph(self.name, self.batch_size)
- for tflite_sg in self.subgraphs:
- sg = Subgraph(tflite_sg.name)
- sg.original_inputs = tflite_sg.inputs # Preserve the original input order
- sg.output_tensors = tflite_sg.outputs
- self.nng.subgraphs.append(sg)
-
- # Preserve the original metadata
- for idx in range(model.MetadataLength()):
- meta = model.Metadata(idx)
- name = meta.Name()
- if name is not None:
- buf_data = self.buffers[meta.Buffer()]
- self.nng.metadata.append((name, buf_data))
+ try:
+ parsing_step = "parsing root"
+ model = Model.GetRootAsModel(buf, 0)
+
+ parsing_step = "parsing buffers length"
+ self.buffers = []
+ for idx in range(model.BuffersLength()):
+ parsing_step = f"parsing buffer {idx}"
+ self.buffers.append(self.parse_buffer(model.Buffers(idx)))
+
+ parsing_step = "parsing operator codes length"
+ self.operator_codes = []
+ for idx in range(model.OperatorCodesLength()):
+ parsing_step = f"parsing operator code {idx}"
+ self.operator_codes.append(self.parse_operator_code(model.OperatorCodes(idx)))
+
+ parsing_step = "parsing subgraphs length"
+ self.subgraphs = []
+ for idx in range(model.SubgraphsLength()):
+ parsing_step = f"parsing subgraph {idx}"
+ self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx)))
+
+ self.nng = Graph(self.name, self.batch_size)
+ for tflite_sg in self.subgraphs:
+ sg = Subgraph(tflite_sg.name)
+ sg.original_inputs = tflite_sg.inputs # Preserve the original input order
+ sg.output_tensors = tflite_sg.outputs
+ self.nng.subgraphs.append(sg)
+
+ parsing_step = "parsing metadata length"
+ # Preserve the original metadata
+ for idx in range(model.MetadataLength()):
+ parsing_step = f"parsing metadata {idx}"
+ meta = model.Metadata(idx)
+ parsing_step = f"parsing metadata name of metadata {idx}"
+ name = meta.Name()
+ if name is not None:
+ parsing_step = f"parsing metadata {idx} ({name})"
+ buf_data = self.buffers[meta.Buffer()]
+ self.nng.metadata.append((name, buf_data))
+ except (struct.error, TypeError, RuntimeError) as e:
+ print(f'Error: Invalid tflite file. Got "{e}" while {parsing_step}.')
+ sys.exit(1)
def parse_buffer(self, buf_data):
if buf_data.DataLength() == 0: