aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_mapping.py
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2023-03-10 18:11:34 +0000
committerTim Hall <tim.hall@arm.com>2023-04-21 18:23:03 +0100
commit2180a172c31f27899d3bf77bfecccc1768667737 (patch)
tree2147c32ae3226599ba19077cd5f3ccf3ffb6a734 /ethosu/vela/tflite_mapping.py
parent7b3008a905d2a5122e21f945db7d2a2132473c53 (diff)
downloadethos-u-vela-2180a172c31f27899d3bf77bfecccc1768667737.tar.gz
MLBEDSW-7408: MLCE: Crash when serialising model LSTM
- Added checking and reporting of missing operator attributes when reading and writing TFLite file - Added a TFLite semantic check to ensure that all required attribute fields of builtin operators are read - Added some sanity checks for RESHAPE operators that run on the Ethos-U - Stopped CPU operators from having their attributes modified Change-Id: I05700681acdb09554f5945819717c08a9457295c Signed-off-by: Tim Hall <tim.hall@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_mapping.py')
-rw-r--r--ethosu/vela/tflite_mapping.py22
1 files changed, 17 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 98fe287d..14052ce5 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -418,10 +418,11 @@ class OptionsSerializer:
def deserialize(self, op_data):
builtin_options = op_data.BuiltinOptions()
attrs = {}
+ attrs["attribute_read_error"] = [] # list of attributes that couldn't be read, empty indicates no error
if builtin_options:
tfattrs = self.cls()
tfattrs.Init(builtin_options.Bytes, builtin_options.Pos)
- for underscore_mem, camelcase_mem, deserialize, serialize, is_vector in self.members:
+ for underscore_mem, camelcase_mem, deserialize, _, is_vector in self.members:
fun = camelcase_mem
if is_vector:
fun += "AsNumpy"
@@ -430,15 +431,22 @@ class OptionsSerializer:
try:
attrs[underscore_mem] = deserialize(attr)
except TypeError:
- print("Warning: {0} could not read attribute '{1}'.".format(self.name, underscore_mem))
+ attrs["attribute_read_error"].append(underscore_mem)
+ else:
+ # all builtin operators should have an options field
+ attrs["attribute_read_error"] = [underscore_mem for underscore_mem, *_ in self.members]
return attrs
def serialize(self, builder, attrs):
ser_attrs = []
- for underscore_mem, camelcase_mem, deserialize, serialize, is_vector in self.members:
- a = serialize(builder, attrs[underscore_mem])
- ser_attrs.append((camelcase_mem, a))
+ attrs["attribute_write_error"] = [] # list of attributes that couldn't be read, empty indicates no error
+ for underscore_mem, camelcase_mem, _, serialize, _ in self.members:
+ try:
+ a = serialize(builder, attrs[underscore_mem])
+ ser_attrs.append((camelcase_mem, a))
+ except KeyError:
+ attrs["attribute_write_error"].append(underscore_mem)
getattr(self.module, self.name + "Start")(builder)
@@ -457,6 +465,8 @@ class CustomOptionsSerializer:
def deserialize(self, op_data):
attrs = {}
+ attrs["attribute_read_error"] = [] # list of attributes that couldn't be read, empty indicates no error
+
custom_options = op_data.CustomOptionsAsNumpy()
attrs["custom_options"] = custom_options
attrs["custom_options_format"] = op_data.CustomOptionsFormat()
@@ -467,6 +477,8 @@ class CustomOptionsSerializer:
return attrs
def serialize(self, builder, attrs):
+ attrs["attribute_write_error"] = [] # list of attributes that couldn't be written, empty indicates no error
+
custom_type = attrs.get("custom_type", CustomType.ThirdPartyOp)
self.custom_opt_format = attrs.get("custom_options_format", self.CUSTOM_OPTIONS_FORMAT_DEFAULT)