diff options
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r-- | ethosu/vela/tflite_writer.py | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 8db3e5b8..7e805e31 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -22,7 +22,7 @@ from flatbuffers import encode from flatbuffers.builder import UOffsetTFlags from .nn_graph import PassPlacement -from .tensor import MemArea +from .tensor import MemType from .tensor import TensorPurpose from .tflite import Buffer from .tflite import Metadata @@ -74,6 +74,7 @@ class TFLiteSerialiser: self.nng = nng self.scratch_buf_id = 0 # Always assign scratch to buffer 0 + self.scratch_fast_buf_id = 1 # Always assign scratch_fast to buffer 1 self.buffer_offsets_map = {} self.buffers_to_write = [] # have an empty array there @@ -140,11 +141,16 @@ class TFLiteSerialiser: scratch_tensor_mem_area = None # all tensors are initialised to MemArea.Unknown buffer_map = {} + buf_idx = 1 for tens in tensors: - if tens.mem_area == scratch_tensor_mem_area: + # Set buffer ids depending on allocation + if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area): buffer_map[tens] = self.scratch_buf_id + elif tens.mem_type == MemType.Scratch_fast: + # For Scratch_fast when not co-allocated with scratch in the TensorArena: + buffer_map[tens] = self.scratch_fast_buf_id else: buffer_map[tens] = buf_idx buf_idx += 1 @@ -229,11 +235,9 @@ class TFLiteSerialiser: if tens.purpose == TensorPurpose.Scratch: tens_shape = [0] - self.buffers_to_write[self.scratch_buf_id] = values.flatten().view(np.uint8) buf_id = self.buffer_map[tens] - if buf_id != self.scratch_buf_id: - self.buffers_to_write[buf_id] = values.flatten().view(np.uint8) + self.buffers_to_write[buf_id] = values.flatten().view(np.uint8) shape = self.write_int_vector(tens_shape) @@ -396,7 +400,8 @@ class TFLiteSerialiser: # Ensure that the order of the offsets match the order of the tensors for tens, idx in self.tensor_map.items(): - if tens.mem_area == MemArea.Sram: + # Set offsets for tensor allocated in Tensor Arena or in the scratch_fast area + if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)): offsets[idx] = np.int32(tens.address) metadata_buffer = np.array([version, subgraph_idx, nbr_tensors] + offsets) |