aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorwilisa01 <william.isaksson@arm.com>2023-04-13 17:05:09 +0000
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-05-04 08:45:46 +0000
commit0a7d5ee98dfc8c881372bc5a50be37aed209c30e (patch)
tree7b88b1cc4bae5fa835f16f0c0d51fe7d4e14a7af /ethosu/vela/tflite_reader.py
parent50550d60121a3ca39b086d643163e7c74ccee837 (diff)
downloadethos-u-vela-0a7d5ee98dfc8c881372bc5a50be37aed209c30e.tar.gz
MLBEDSW-7504: Vela does not keep op version number
We now read operator code version, store it in operator and write it out to optimized file. Signed-off-by: wilisa01 <william.isaksson@arm.com> Change-Id: Idba672531d2e2a0203a85d3ffca9cf65ace85b47
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 061f3626..85acb6b8 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -117,7 +117,7 @@ class TFLiteSubgraph:
return tens
def parse_operator(self, op_index, op_data):
- op_type, opt_serializer, custom_code, indices = self.graph.operator_codes[op_data.OpcodeIndex()]
+ op_type, opt_serializer, custom_code, indices, version = self.graph.operator_codes[op_data.OpcodeIndex()]
inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()]
outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()]
intermediates = []
@@ -130,6 +130,7 @@ class TFLiteSubgraph:
inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
op = Operation(op_type, name)
op.op_index = op_index
+ op.version = version
op.inputs = inputs
op.outputs = outputs
op.intermediates = intermediates
@@ -338,7 +339,7 @@ class TFLiteGraph:
custom_code = None
if c == BuiltinOperator.CUSTOM:
custom_code = decode_str(code.CustomCode())
- return op_type, ser, custom_code, indices
+ return op_type, ser, custom_code, indices, code.Version()
def read_tflite(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):