aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael McGeagh <michael.mcgeagh@arm.com>2020-08-07 16:21:03 +0100
committerFredrik Knutsson <fredrik.knutsson.hunnebo@gmail.com>2020-08-12 06:25:23 +0000
commit22f74e1c39572f084ad05cc2f208446fd2f50138 (patch)
tree96d602818de364e2297fb67c2f00fc1245ba74f7
parente99b893beaa1b95ee86d51a613f208f9f4edf150 (diff)
downloadethos-u-vela-22f74e1c39572f084ad05cc2f208446fd2f50138.tar.gz
MLBEDSW-2383 Preserve previous metadata
The input tflite file potentially has metadata attached to it, which was lost when writing the vela optimised tflite file out. This patch preserves any metadata found. Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com> Change-Id: I7b4e941696d21b81802fd4398cd405323778bedf
-rw-r--r--ethosu/vela/nn_graph.py2
-rw-r--r--ethosu/vela/tflite_reader.py8
-rw-r--r--ethosu/vela/tflite_writer.py10
3 files changed, 15 insertions, 5 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 4a2855b2..bfab2270 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -502,7 +502,7 @@ class Graph:
self.name = name
self.batch_size = batch_size
self.subgraphs = []
-
+ self.metadata = []
self.memory_used = {}
self.bits_per_element = {}
self.total_size = {}
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 9346b760..bf3fe950 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -245,6 +245,14 @@ class TFLiteGraph:
sg.output_tensors = tflite_sg.outputs
self.nng.subgraphs.append(sg)
+ # Preserve the original metadata
+ for idx in range(model.MetadataLength()):
+ meta = model.Metadata(idx)
+ name = meta.Name()
+ if name is not None:
+ buf_data = self.buffers[meta.Buffer()]
+ self.nng.metadata.append((name, buf_data))
+
def parse_buffer(self, buf_data):
if buf_data.DataLength() == 0:
return None
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 3f3b7b1b..92b5c6b0 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -414,12 +414,14 @@ class TFLiteSerialiser:
if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)) and tens.address is not None:
offsets[idx] = np.int32(tens.address)
- metadata_buffer = np.array([version, subgraph_idx, nbr_tensors] + offsets)
- self.buffers_to_write.append(metadata_buffer)
+ self.nng.metadata.append(("OfflineMemoryAllocation", np.array([version, subgraph_idx, nbr_tensors] + offsets)))
- buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write])
+ metadata_list = []
+ for name, buffer in self.nng.metadata:
+ self.buffers_to_write.append(buffer)
+ metadata_list.append((name, len(self.buffers_to_write) - 1))
- metadata_list = [("OfflineMemoryAllocation", len(self.buffers_to_write) - 1)]
+ buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write])
metadata_offset = self.write_offset_vector([self.serialise_metadata(metadata) for metadata in metadata_list])
Model.ModelStart(builder)