aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/live_range.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2022-03-01 11:26:58 +0100
committerFredrik Svedberg <fredrik.svedberg@arm.com>2022-03-30 13:00:15 +0000
commitcc5f4de1c35ba44fca7ff6295c6ae846f8242344 (patch)
tree68c4f8124a3ee6ec6f7fceb32a1d8aec11ac9a86 /ethosu/vela/live_range.py
parenta19b4671dd0594181a2789930cc98bf5dc41ded4 (diff)
downloadethos-u-vela-cc5f4de1c35ba44fca7ff6295c6ae846f8242344.tar.gz
MLBEDSW-6263: Use separate tensors for double buffering
Uses separate tensors for the individual weight buffers in case of weight double buffering. Each weight buffer tensor gets its own individual live range. Change-Id: I724a8c61a7045615fbd2ed9535663076ac8edd13 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r--ethosu/vela/live_range.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index fc94e9dd..45baf440 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -63,7 +63,7 @@ class LiveRange:
def mark_usage(self, op_time, op_length=1):
op_time_start = max(op_time, 0)
op_time_end = op_time + op_length
- if op_time_end <= op_time_start:
+ if op_time_end < op_time_start:
return
self.start_time = min(self.start_time, op_time_start)
@@ -321,13 +321,20 @@ def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set,
rng.mark_usage(time_to_set)
- weight_tens = op_info.buffered_weight_tensor
- if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
- rng = lr_graph.get_or_create_range(weight_tens)
- if weight_tens.pre_buffer:
- rng.mark_usage(time_to_set - 1, 2)
- else:
- rng.mark_usage(time_to_set)
+ for idx, weight_tens in enumerate(op_info.buffered_weight_tensors):
+ if weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+ rng = lr_graph.get_or_create_range(weight_tens)
+ start_time = time_to_set
+ length = 1
+ if weight_tens.pre_buffer:
+ start_time -= 1
+ length += 1
+ if len(op_info.buffered_weight_tensors) > 1:
+ last_idx = len(op_info.ofm_depth_slices) % len(op_info.buffered_weight_tensors)
+ # Double buffering: reduce end time of the buffer that is not used last
+ if last_idx != idx:
+ length -= 1
+ rng.mark_usage(start_time, length)
if time_to_set == lr_graph.current_time:
lr_graph.current_time += 2