diff options
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r-- | ethosu/vela/tflite_writer.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 99df849b..675b6985 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -134,12 +134,17 @@ class TFLiteSerialiser: return builder.EndVector(len(v)) def assign_buffers_to_tensors(self, tensors): + scratch_tensors = [tens for tens in tensors if tens.purpose == TensorPurpose.Scratch] + if len(scratch_tensors) > 0: + scratch_tensor_mem_area = scratch_tensors[0].mem_area + else: + scratch_tensor_mem_area = None # all tensors are initialised to MemArea.Unknown + buffer_map = {} - scratch_tensor = [tens for tens in tensors if tens.purpose == TensorPurpose.Scratch][0] buf_idx = 1 for tens in tensors: - if tens.mem_area == scratch_tensor.mem_area: + if tens.mem_area == scratch_tensor_mem_area: buffer_map[tens] = self.scratch_buf_id else: buffer_map[tens] = buf_idx |