aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/nn_graph.py2
-rw-r--r--ethosu/vela/tflite_reader.py8
-rw-r--r--ethosu/vela/tflite_writer.py10
3 files changed, 15 insertions, 5 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 4a2855b2..bfab2270 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -502,7 +502,7 @@ class Graph:
self.name = name
self.batch_size = batch_size
self.subgraphs = []
-
+ self.metadata = []
self.memory_used = {}
self.bits_per_element = {}
self.total_size = {}
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 9346b760..bf3fe950 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -245,6 +245,14 @@ class TFLiteGraph:
sg.output_tensors = tflite_sg.outputs
self.nng.subgraphs.append(sg)
+ # Preserve the original metadata
+ for idx in range(model.MetadataLength()):
+ meta = model.Metadata(idx)
+ name = meta.Name()
+ if name is not None:
+ buf_data = self.buffers[meta.Buffer()]
+ self.nng.metadata.append((name, buf_data))
+
def parse_buffer(self, buf_data):
if buf_data.DataLength() == 0:
return None
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 3f3b7b1b..92b5c6b0 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -414,12 +414,14 @@ class TFLiteSerialiser:
if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)) and tens.address is not None:
offsets[idx] = np.int32(tens.address)
- metadata_buffer = np.array([version, subgraph_idx, nbr_tensors] + offsets)
- self.buffers_to_write.append(metadata_buffer)
+ self.nng.metadata.append(("OfflineMemoryAllocation", np.array([version, subgraph_idx, nbr_tensors] + offsets)))
- buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write])
+ metadata_list = []
+ for name, buffer in self.nng.metadata:
+ self.buffers_to_write.append(buffer)
+ metadata_list.append((name, len(self.buffers_to_write) - 1))
- metadata_list = [("OfflineMemoryAllocation", len(self.buffers_to_write) - 1)]
+ buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write])
metadata_offset = self.write_offset_vector([self.serialise_metadata(metadata) for metadata in metadata_list])
Model.ModelStart(builder)