aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2023-03-10 18:11:34 +0000
committerTim Hall <tim.hall@arm.com>2023-04-21 18:23:03 +0100
commit2180a172c31f27899d3bf77bfecccc1768667737 (patch)
tree2147c32ae3226599ba19077cd5f3ccf3ffb6a734 /ethosu/vela/tflite_reader.py
parent7b3008a905d2a5122e21f945db7d2a2132473c53 (diff)
downloadethos-u-vela-2180a172c31f27899d3bf77bfecccc1768667737.tar.gz
MLBEDSW-7408: MLCE: Crash when serialising model LSTM
- Added checking and reporting of missing operator attributes when reading and writing TFLite file - Added a TFLite semantic check to ensure that all required attribute fields of builtin operators are read - Added some sanity checks for RESHAPE operators that run on the Ethos-U - Stopped CPU operators from having their attributes modified Change-Id: I05700681acdb09554f5945819717c08a9457295c Signed-off-by: Tim Hall <tim.hall@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 2325ff65..2f3192b7 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -41,6 +41,7 @@ from .tflite_mapping import builtin_operator_map
from .tflite_mapping import DataType
from .tflite_mapping import datatype_map
from .tflite_mapping import datatype_map_numpy
+from .tflite_mapping import optype_to_builtintype
class TFLiteSubgraph:
@@ -180,9 +181,11 @@ class TFLiteSubgraph:
init_subgraph_index = op.attrs["init_subgraph_index"]
op.attrs["subgraph"] = (self.graph.nng.subgraphs[init_subgraph_index],)
- if op_type == Op.Reshape and "new_shape" not in op.attrs:
- # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
- op.attrs["new_shape"] = outputs[0].shape
+ if op_type == Op.Reshape:
+ if "new_shape" in op.attrs["attribute_read_error"] and len(inputs) > 1:
+ # the "new_shape" attribute is optional if the new_shape tensor (inputs[1]) is specified. therefore,
+ # remove the attribute read error
+ op.attrs["attribute_read_error"].remove("new_shape")
if op_type == Op.Cast:
# Cast op should have "in/out_data_type" attribs add if missing
@@ -212,6 +215,14 @@ class TFLiteSubgraph:
if custom_code is not None:
op.attrs["custom_code"] = custom_code
+ # finally, report any missing attributes that could not be read during deserialize()
+ attribute_read_error = op.attrs["attribute_read_error"]
+ if len(attribute_read_error) != 0:
+ print(
+ f"Warning: Could not read the following attributes from {optype_to_builtintype(op.type)}"
+ f" '{op.name}' {opt_serializer.name} field: {', '.join(attribute_read_error)}"
+ )
+
@staticmethod
def len1_array_to_scalar(arr):
# The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in