diff options
author | Tim Hall <tim.hall@arm.com> | 2023-03-10 18:11:34 +0000 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2023-04-21 18:23:03 +0100 |
commit | 2180a172c31f27899d3bf77bfecccc1768667737 (patch) | |
tree | 2147c32ae3226599ba19077cd5f3ccf3ffb6a734 /ethosu/vela/tflite_mapping.py | |
parent | 7b3008a905d2a5122e21f945db7d2a2132473c53 (diff) | |
download | ethos-u-vela-2180a172c31f27899d3bf77bfecccc1768667737.tar.gz |
MLBEDSW-7408: MLCE: Crash when serialising model LSTM
- Added checking and reporting of missing operator attributes when
reading and writing TFLite file
- Added a TFLite semantic check to ensure that all required attribute
fields of builtin operators are read
- Added some sanity checks for RESHAPE operators that run on the
Ethos-U
- Stopped CPU operators from having their attributes modified
Change-Id: I05700681acdb09554f5945819717c08a9457295c
Signed-off-by: Tim Hall <tim.hall@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_mapping.py')
-rw-r--r-- | ethosu/vela/tflite_mapping.py | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py index 98fe287d..14052ce5 100644 --- a/ethosu/vela/tflite_mapping.py +++ b/ethosu/vela/tflite_mapping.py @@ -418,10 +418,11 @@ class OptionsSerializer: def deserialize(self, op_data): builtin_options = op_data.BuiltinOptions() attrs = {} + attrs["attribute_read_error"] = [] # list of attributes that couldn't be read, empty indicates no error if builtin_options: tfattrs = self.cls() tfattrs.Init(builtin_options.Bytes, builtin_options.Pos) - for underscore_mem, camelcase_mem, deserialize, serialize, is_vector in self.members: + for underscore_mem, camelcase_mem, deserialize, _, is_vector in self.members: fun = camelcase_mem if is_vector: fun += "AsNumpy" @@ -430,15 +431,22 @@ class OptionsSerializer: try: attrs[underscore_mem] = deserialize(attr) except TypeError: - print("Warning: {0} could not read attribute '{1}'.".format(self.name, underscore_mem)) + attrs["attribute_read_error"].append(underscore_mem) + else: + # all builtin operators should have an options field + attrs["attribute_read_error"] = [underscore_mem for underscore_mem, *_ in self.members] return attrs def serialize(self, builder, attrs): ser_attrs = [] - for underscore_mem, camelcase_mem, deserialize, serialize, is_vector in self.members: - a = serialize(builder, attrs[underscore_mem]) - ser_attrs.append((camelcase_mem, a)) + attrs["attribute_write_error"] = [] # list of attributes that couldn't be read, empty indicates no error + for underscore_mem, camelcase_mem, _, serialize, _ in self.members: + try: + a = serialize(builder, attrs[underscore_mem]) + ser_attrs.append((camelcase_mem, a)) + except KeyError: + attrs["attribute_write_error"].append(underscore_mem) getattr(self.module, self.name + "Start")(builder) @@ -457,6 +465,8 @@ class CustomOptionsSerializer: def deserialize(self, op_data): attrs = {} + attrs["attribute_read_error"] = [] # list of attributes that couldn't be read, empty indicates no error + custom_options = op_data.CustomOptionsAsNumpy() attrs["custom_options"] = custom_options attrs["custom_options_format"] = op_data.CustomOptionsFormat() @@ -467,6 +477,8 @@ class CustomOptionsSerializer: return attrs def serialize(self, builder, attrs): + attrs["attribute_write_error"] = [] # list of attributes that couldn't be written, empty indicates no error + custom_type = attrs.get("custom_type", CustomType.ThirdPartyOp) self.custom_opt_format = attrs.get("custom_options_format", self.CUSTOM_OPTIONS_FORMAT_DEFAULT) |