diff options
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/nn_graph.py | 2 | ||||
-rw-r--r-- | ethosu/vela/tflite_reader.py | 8 | ||||
-rw-r--r-- | ethosu/vela/tflite_writer.py | 10 |
3 files changed, 15 insertions, 5 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 4a2855b2..bfab2270 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -502,7 +502,7 @@ class Graph: self.name = name self.batch_size = batch_size self.subgraphs = [] - + self.metadata = [] self.memory_used = {} self.bits_per_element = {} self.total_size = {} diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 9346b760..bf3fe950 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -245,6 +245,14 @@ class TFLiteGraph: sg.output_tensors = tflite_sg.outputs self.nng.subgraphs.append(sg) + # Preserve the original metadata + for idx in range(model.MetadataLength()): + meta = model.Metadata(idx) + name = meta.Name() + if name is not None: + buf_data = self.buffers[meta.Buffer()] + self.nng.metadata.append((name, buf_data)) + def parse_buffer(self, buf_data): if buf_data.DataLength() == 0: return None diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 3f3b7b1b..92b5c6b0 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -414,12 +414,14 @@ class TFLiteSerialiser: if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)) and tens.address is not None: offsets[idx] = np.int32(tens.address) - metadata_buffer = np.array([version, subgraph_idx, nbr_tensors] + offsets) - self.buffers_to_write.append(metadata_buffer) + self.nng.metadata.append(("OfflineMemoryAllocation", np.array([version, subgraph_idx, nbr_tensors] + offsets))) - buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write]) + metadata_list = [] + for name, buffer in self.nng.metadata: + self.buffers_to_write.append(buffer) + metadata_list.append((name, len(self.buffers_to_write) - 1)) - metadata_list = [("OfflineMemoryAllocation", len(self.buffers_to_write) - 1)] + buffers_offset = self.write_offset_vector([self.serialise_buffer(buf) for buf in self.buffers_to_write]) metadata_offset = self.write_offset_vector([self.serialise_metadata(metadata) for metadata in metadata_list]) Model.ModelStart(builder) |