diff options
author | Michael McGeagh <michael.mcgeagh@arm.com> | 2020-08-07 16:21:03 +0100 |
---|---|---|
committer | Fredrik Knutsson <fredrik.knutsson.hunnebo@gmail.com> | 2020-08-12 06:25:23 +0000 |
commit | 22f74e1c39572f084ad05cc2f208446fd2f50138 (patch) | |
tree | 96d602818de364e2297fb67c2f00fc1245ba74f7 /ethosu | |
parent | e99b893beaa1b95ee86d51a613f208f9f4edf150 (diff) | |
download | ethos-u-vela-22f74e1c39572f084ad05cc2f208446fd2f50138.tar.gz |
MLBEDSW-2383 Preserve previous metadata
The input tflite file potentially has metadata attached to it, which was
lost when writing the vela optimised tflite file out.
This patch preserves any metadata found.
Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: I7b4e941696d21b81802fd4398cd405323778bedf
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/nn_graph.py | 2 | ||||
-rw-r--r-- | ethosu/vela/tflite_reader.py | 8 | ||||
-rw-r--r-- | ethosu/vela/tflite_writer.py | 10 |
3 files changed, 15 insertions, 5 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 4a2855b2..bfab2270 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -502,7 +502,7 @@ class Graph: self.name = name self.batch_size = batch_size self.subgraphs = [] - + self.metadata = [] self.memory_used = {} self.bits_per_element = {} self.total_size = {} diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 9346b760..bf3fe950 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -245,6 +245,14 @@ class TFLiteGraph: sg.output_tensors = tflite_sg.outputs self.nng.subgraphs.append(sg) + # Preserve the original metadata + for idx in range(model.MetadataLength()): + meta = model.Metadata(idx) + name = meta.Name() + if name is not None: + buf_data = self.buffers[meta.Buffer()] + self.nng.metadata.append((name, buf_data)) + def parse_buffer(self, buf_data): if buf_data.DataLength() == 0: return None diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 3f3b7b1b..92b5c6b0 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -414,12 +414,14 @@ class TFLiteSerialiser: if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)) and tens.address is not None: offsets[idx] = np.int32(tens.address) - metadata_buffer = np.array([version, subgraph_idx, nbr_tensors] + offsets) - self.buffers_to_write.append(metadata_buffer) + self.nng.metadata.append(("OfflineMemoryAllocation", np.array([version, subgraph_idx, nbr_tensors] + offsets))) - buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write]) + metadata_list = [] + for name, buffer in self.nng.metadata: + self.buffers_to_write.append(buffer) + metadata_list.append((name, len(self.buffers_to_write) - 1)) - metadata_list = [("OfflineMemoryAllocation", len(self.buffers_to_write) - 1)] + buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write]) metadata_offset = self.write_offset_vector([self.serialise_metadata(metadata) for metadata in metadata_list]) Model.ModelStart(builder) |