diff options
-rw-r--r-- | ethosu/vela/tflite_writer.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index c8250c6e..2e7345ce 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -75,9 +75,10 @@ def make_vector(v): class TFLiteSerialiser: - BUF_IDX_SCRATCH = 0 # Always assign scratch to buffer 0 - BUF_IDX_SCRATCH_FAST = 1 # Always assign scratch_fast to buffer 1 - BUF_IDX_START = 2 # Unique buffer id for every tensor in all subgraphs + # The 0th buffer is always by default an empty buffer that can be used by tensors + # without any constant data + BUF_IDX_ZERO = 0 + BUF_IDX_START = 1 def __init__(self, nng): self.builder = flatbuffers.Builder(0) @@ -161,11 +162,10 @@ class TFLiteSerialiser: for tens in tensors: # Set buffer ids depending on allocation - if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area): - buffer_map[tens] = TFLiteSerialiser.BUF_IDX_SCRATCH - elif tens.mem_type == MemType.Scratch_fast: - # For Scratch_fast when not co-allocated with scratch in the TensorArena: - buffer_map[tens] = TFLiteSerialiser.BUF_IDX_SCRATCH_FAST + if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area) or tens.mem_type == MemType.Scratch_fast: + # Tensor allocated in the scratch areas, does not have any constant data and can + # therefore all point to the empty buffer (zero) + buffer_map[tens] = TFLiteSerialiser.BUF_IDX_ZERO else: buffer_map[tens] = self.buf_idx self.buf_idx += 1 @@ -256,6 +256,8 @@ class TFLiteSerialiser: values = tens.values buf_id = self.buffer_map[tens] + # Sanity check that if buffer 0 is used there must not be any data + assert not (buf_id == TFLiteSerialiser.BUF_IDX_ZERO and values is not None) self.buffers_to_write[buf_id] = None if values is None else values.flatten().view(np.uint8) shape = self.write_int_vector(tens_shape) |