diff options
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r-- | ethosu/vela/tflite_writer.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 4aa23b5f..cf40b5b5 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -142,7 +142,7 @@ class TFLiteSerialiser: buffer_map = {} - buf_idx = 1 + buf_idx = 2 for tens in tensors: # Set buffer ids depending on allocation @@ -314,7 +314,11 @@ class TFLiteSerialiser: all_tensors = [tens for nm, idx, tens in sorted((tens.name, idx, tens) for idx, tens in enumerate(tensor_set))] - scratch_tensors = [tens for tens in all_tensors if tens.purpose == TensorPurpose.Scratch] + scratch_tensors = [tens for tens in all_tensors if tens.name.endswith("scratch")] + + for tens in all_tensors: + if tens.name.endswith("scratch_fast"): + scratch_fast_tensor = tens if len(scratch_tensors) == 0: scratch_tensor = None @@ -331,11 +335,16 @@ class TFLiteSerialiser: assert all(inp in sg.original_inputs for inp in sg.input_tensors) inputs = [self.tensor_map[tens] for tens in sg.original_inputs] - # Add the Scratch Tensor as input to the NPU subgraph to get it allocated by TensorFlow Lite Micro + # Add the Scratch Tensors as input to the NPU subgraph to get them allocated by TensorFlow Lite Micro scratch_tensor_idx = self.tensor_map.get(scratch_tensor, None) + scratch_fast_tensor_idx = self.tensor_map.get(scratch_fast_tensor, None) + if scratch_tensor_idx is not None and scratch_tensor_idx not in inputs: inputs.append(scratch_tensor_idx) + if scratch_fast_tensor_idx is not None and scratch_fast_tensor_idx not in inputs: + inputs.append(scratch_fast_tensor_idx) + inputs_offset = self.write_int_vector(inputs) outputs_offset = self.write_int_vector([self.tensor_map[tens] for tens in sg.output_tensors]) |