diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 2325ff65..2f3192b7 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -41,6 +41,7 @@ from .tflite_mapping import builtin_operator_map from .tflite_mapping import DataType from .tflite_mapping import datatype_map from .tflite_mapping import datatype_map_numpy +from .tflite_mapping import optype_to_builtintype class TFLiteSubgraph: @@ -180,9 +181,11 @@ class TFLiteSubgraph: init_subgraph_index = op.attrs["init_subgraph_index"] op.attrs["subgraph"] = (self.graph.nng.subgraphs[init_subgraph_index],) - if op_type == Op.Reshape and "new_shape" not in op.attrs: - # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape - op.attrs["new_shape"] = outputs[0].shape + if op_type == Op.Reshape: + if "new_shape" in op.attrs["attribute_read_error"] and len(inputs) > 1: + # the "new_shape" attribute is optional if the new_shape tensor (inputs[1]) is specified. therefore, + # remove the attribute read error + op.attrs["attribute_read_error"].remove("new_shape") if op_type == Op.Cast: # Cast op should have "in/out_data_type" attribs add if missing @@ -212,6 +215,14 @@ class TFLiteSubgraph: if custom_code is not None: op.attrs["custom_code"] = custom_code + # finally, report any missing attributes that could not be read during deserialize() + attribute_read_error = op.attrs["attribute_read_error"] + if len(attribute_read_error) != 0: + print( + f"Warning: Could not read the following attributes from {optype_to_builtintype(op.type)}" + f" '{op.name}' {opt_serializer.name} field: {', '.join(attribute_read_error)}" + ) + @staticmethod def len1_array_to_scalar(arr): # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in |