aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/tflite_writer.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index c8250c6e..2e7345ce 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -75,9 +75,10 @@ def make_vector(v):
class TFLiteSerialiser:
- BUF_IDX_SCRATCH = 0 # Always assign scratch to buffer 0
- BUF_IDX_SCRATCH_FAST = 1 # Always assign scratch_fast to buffer 1
- BUF_IDX_START = 2 # Unique buffer id for every tensor in all subgraphs
+ # The 0th buffer is always by default an empty buffer that can be used by tensors
+ # without any constant data
+ BUF_IDX_ZERO = 0
+ BUF_IDX_START = 1
def __init__(self, nng):
self.builder = flatbuffers.Builder(0)
@@ -161,11 +162,10 @@ class TFLiteSerialiser:
for tens in tensors:
# Set buffer ids depending on allocation
- if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area):
- buffer_map[tens] = TFLiteSerialiser.BUF_IDX_SCRATCH
- elif tens.mem_type == MemType.Scratch_fast:
- # For Scratch_fast when not co-allocated with scratch in the TensorArena:
- buffer_map[tens] = TFLiteSerialiser.BUF_IDX_SCRATCH_FAST
+ if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area) or tens.mem_type == MemType.Scratch_fast:
+ # Tensor allocated in the scratch areas, does not have any constant data and can
+ # therefore all point to the empty buffer (zero)
+ buffer_map[tens] = TFLiteSerialiser.BUF_IDX_ZERO
else:
buffer_map[tens] = self.buf_idx
self.buf_idx += 1
@@ -256,6 +256,8 @@ class TFLiteSerialiser:
values = tens.values
buf_id = self.buffer_map[tens]
+ # Sanity check that if buffer 0 is used there must not be any data
+ assert not (buf_id == TFLiteSerialiser.BUF_IDX_ZERO and values is not None)
self.buffers_to_write[buf_id] = None if values is None else values.flatten().view(np.uint8)
shape = self.write_int_vector(tens_shape)