aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2020-05-27 09:15:11 +0200
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2020-06-25 11:42:56 +0200
commiteca2e95e1fea150d8a942f8b5f0a4d9d7aefebc1 (patch)
tree438b385f1ded3c18c3b84d2204a57c39be6be34a /ethosu/vela/tflite_writer.py
parenteec4e50e19cb5522640eae5fd4566917dc2a7b9d (diff)
downloadethos-u-vela-eca2e95e1fea150d8a942f8b5f0a4d9d7aefebc1.tar.gz
MLBEDSW-2306 Added more supported mem-cfgs
Additional supported memory configurations: -Permanent_storage = DRAM -Tensor arena either in DRAM or SRAM Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I20beb7151e306bfdba540e7c0b2a7b478b4d94e1
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 8db3e5b8..7e805e31 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -22,7 +22,7 @@ from flatbuffers import encode
from flatbuffers.builder import UOffsetTFlags
from .nn_graph import PassPlacement
-from .tensor import MemArea
+from .tensor import MemType
from .tensor import TensorPurpose
from .tflite import Buffer
from .tflite import Metadata
@@ -74,6 +74,7 @@ class TFLiteSerialiser:
self.nng = nng
self.scratch_buf_id = 0 # Always assign scratch to buffer 0
+ self.scratch_fast_buf_id = 1 # Always assign scratch_fast to buffer 1
self.buffer_offsets_map = {}
self.buffers_to_write = [] # have an empty array there
@@ -140,11 +141,16 @@ class TFLiteSerialiser:
scratch_tensor_mem_area = None # all tensors are initialised to MemArea.Unknown
buffer_map = {}
+
buf_idx = 1
for tens in tensors:
- if tens.mem_area == scratch_tensor_mem_area:
+ # Set buffer ids depending on allocation
+ if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area):
buffer_map[tens] = self.scratch_buf_id
+ elif tens.mem_type == MemType.Scratch_fast:
+ # For Scratch_fast when not co-allocated with scratch in the TensorArena:
+ buffer_map[tens] = self.scratch_fast_buf_id
else:
buffer_map[tens] = buf_idx
buf_idx += 1
@@ -229,11 +235,9 @@ class TFLiteSerialiser:
if tens.purpose == TensorPurpose.Scratch:
tens_shape = [0]
- self.buffers_to_write[self.scratch_buf_id] = values.flatten().view(np.uint8)
buf_id = self.buffer_map[tens]
- if buf_id != self.scratch_buf_id:
- self.buffers_to_write[buf_id] = values.flatten().view(np.uint8)
+ self.buffers_to_write[buf_id] = values.flatten().view(np.uint8)
shape = self.write_int_vector(tens_shape)
@@ -396,7 +400,8 @@ class TFLiteSerialiser:
# Ensure that the order of the offsets match the order of the tensors
for tens, idx in self.tensor_map.items():
- if tens.mem_area == MemArea.Sram:
+ # Set offsets for tensor allocated in Tensor Arena or in the scratch_fast area
+ if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)):
offsets[idx] = np.int32(tens.address)
metadata_buffer = np.array([version, subgraph_idx, nbr_tensors] + offsets)