aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2023-03-10 18:11:34 +0000
committerTim Hall <tim.hall@arm.com>2023-04-21 18:23:03 +0100
commit2180a172c31f27899d3bf77bfecccc1768667737 (patch)
tree2147c32ae3226599ba19077cd5f3ccf3ffb6a734 /ethosu/vela/tflite_writer.py
parent7b3008a905d2a5122e21f945db7d2a2132473c53 (diff)
downloadethos-u-vela-2180a172c31f27899d3bf77bfecccc1768667737.tar.gz
MLBEDSW-7408: MLCE: Crash when serialising model LSTM
- Added checking and reporting of missing operator attributes when reading and writing TFLite file - Added a TFLite semantic check to ensure that all required attribute fields of builtin operators are read - Added some sanity checks for RESHAPE operators that run on the Ethos-U - Stopped CPU operators from having their attributes modified Change-Id: I05700681acdb09554f5945819717c08a9457295c Signed-off-by: Tim Hall <tim.hall@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py44
1 files changed, 28 insertions, 16 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 32982298..8d44774b 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -40,6 +40,7 @@ from .tflite import Tensor
from .tflite_mapping import builtin_operator_inv_map
from .tflite_mapping import BuiltinOperator
from .tflite_mapping import datatype_inv_map
+from .tflite_mapping import optype_to_builtintype
# the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
@@ -310,25 +311,36 @@ class TFLiteSerialiser:
custom_opt_offset = None
if opt_serializer is not None:
attrs = dict(op.attrs)
- if "strides" in attrs:
- attrs["stride_h"] = attrs["strides"][1]
- attrs["stride_w"] = attrs["strides"][2]
- if "ksize" in attrs:
- attrs["filter_height"] = attrs["ksize"][1]
- attrs["filter_width"] = attrs["ksize"][2]
- if "dilation" in attrs:
- attrs["dilation_h_factor"] = attrs["dilation"][1]
- attrs["dilation_w_factor"] = attrs["dilation"][2]
- if "channel_multiplier" in attrs:
- attrs["depth_multiplier"] = attrs["channel_multiplier"]
- if "container" in attrs:
- attrs["container"] = builder.CreateString(attrs["container"])
- if "shared_name" in attrs:
- attrs["shared_name"] = builder.CreateString(attrs["shared_name"])
- attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None
+ if op.run_on_npu:
+ if "strides" in attrs:
+ attrs["stride_h"] = attrs["strides"][1]
+ attrs["stride_w"] = attrs["strides"][2]
+ if "ksize" in attrs:
+ attrs["filter_height"] = attrs["ksize"][1]
+ attrs["filter_width"] = attrs["ksize"][2]
+ if "dilation" in attrs:
+ attrs["dilation_h_factor"] = attrs["dilation"][1]
+ attrs["dilation_w_factor"] = attrs["dilation"][2]
+ if "channel_multiplier" in attrs:
+ attrs["depth_multiplier"] = attrs["channel_multiplier"]
+ if "container" in attrs:
+ attrs["container"] = builder.CreateString(attrs["container"])
+ if "shared_name" in attrs:
+ attrs["shared_name"] = builder.CreateString(attrs["shared_name"])
+ attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None
builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
+ # report any missing attributes that could not be written during serialize().
+ # operators that have been created internally (i.e. not created as part of reading an input network) may not
+ # have the write error attribute
+ attribute_write_error = attrs.get("attribute_write_error", [])
+ if len(attribute_write_error) != 0:
+ print(
+ f"Warning: Could not write the following attributes to {optype_to_builtintype(op.type)}"
+ f" '{op.name}' {opt_serializer.name} field: {', '.join(attribute_write_error)}"
+ )
+
mutating_variable_inputs_offset = self.write_byte_vector([])
Operator.OperatorStart(builder)
Operator.OperatorAddOpcodeIndex(builder, op_idx)