aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
authorwilisa01 <william.isaksson@arm.com>2023-04-13 17:05:09 +0000
committerFredrik Svedberg <fredrik.svedberg@arm.com>2023-05-04 08:45:46 +0000
commit0a7d5ee98dfc8c881372bc5a50be37aed209c30e (patch)
tree7b88b1cc4bae5fa835f16f0c0d51fe7d4e14a7af /ethosu/vela/tflite_writer.py
parent50550d60121a3ca39b086d643163e7c74ccee837 (diff)
downloadethos-u-vela-0a7d5ee98dfc8c881372bc5a50be37aed209c30e.tar.gz
MLBEDSW-7504: Vela does not keep op version number
We now read operator code version, store it in operator and write it out to optimized file. Signed-off-by: wilisa01 <william.isaksson@arm.com> Change-Id: Idba672531d2e2a0203a85d3ffca9cf65ace85b47
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 2e7345ce..8c03f051 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -108,8 +108,8 @@ class TFLiteSerialiser:
if inp is not None and inp.src_tensor is not None:
op.inputs[idx] = inp.src_tensor
- # list of tuple(Op, string); the custom code is only used for 3rd party custom operators
- self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
+ # list of tuple(Op, string, op.version); the custom code is only used for 3rd party custom operators
+ self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", ""), op.version) for op in all_ops))
self.operator_code_map = {}
def align_nng_inputs_to_tflite(self, op):
@@ -176,7 +176,7 @@ class TFLiteSerialiser:
return buffer_map
- def serialise_operator_code(self, idx, op_type, custom_code):
+ def serialise_operator_code(self, idx, op_type, custom_code, version):
builder = self.builder
custom_code_offset = None
if op_type == Op.Custom:
@@ -207,6 +207,7 @@ class TFLiteSerialiser:
OperatorCode.OperatorCodeStart(builder)
OperatorCode.OperatorCodeAddDeprecatedBuiltinCode(builder, tf_code if tf_code < 127 else 127)
OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
+ OperatorCode.OperatorCodeAddVersion(builder, version)
if custom_code_offset is not None:
OperatorCode.OperatorCodeAddCustomCode(builder, custom_code_offset)
@@ -455,7 +456,10 @@ class TFLiteSerialiser:
def serialise_model(self):
builder = self.builder
operator_code_offset = self.write_offset_vector(
- [self.serialise_operator_code(idx, optype, code) for idx, (optype, code) in enumerate(self.operator_codes)]
+ [
+ self.serialise_operator_code(idx, optype, code, version)
+ for idx, (optype, code, version) in enumerate(self.operator_codes)
+ ]
)
description = builder.CreateString("Vela Optimised")