aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/live_range.py
diff options
context:
space:
mode:
authorRickard Bolin <rickard.bolin@arm.com>2022-05-16 09:11:06 +0000
committerRickard Bolin <rickard.bolin@arm.com>2022-05-16 15:20:20 +0000
commitfd8b500085d1ac1cca54a71631d21713a3c21f09 (patch)
tree4a8d1c7809dc1eb748f0f0b9ba2736e5d7bb5e69 /ethosu/vela/live_range.py
parent6f4cb0362a2f00b3045565de2c27f72997b2998b (diff)
downloadethos-u-vela-fd8b500085d1ac1cca54a71631d21713a3c21f09.tar.gz
MLBEDSW-6263: Use separate tensors for double buffering
Uses separate tensors for the individual weight buffers in case of weight double buffering. Each weight buffer tensor gets its own individual live range. This patch is a clone of a previously reverted patch, but with some additional bug fixes applied. Signed-off-by: Rickard Bolin <rickard.bolin@arm.com> Change-Id: I868c70d15821eb9f1399186f2da6e7345f6ee343
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r--ethosu/vela/live_range.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 19d0c11f..ccf49297 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -63,7 +63,7 @@ class LiveRange:
def mark_usage(self, op_time, op_length=1):
op_time_start = max(op_time, 0)
op_time_end = op_time + op_length
- if op_time_end <= op_time_start:
+ if op_time_end < op_time_start:
return
self.start_time = min(self.start_time, op_time_start)
@@ -325,13 +325,20 @@ def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set,
rng.mark_usage(time_to_set)
- weight_tens = op_info.buffered_weight_tensor
- if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
- rng = lr_graph.get_or_create_range(weight_tens)
- if weight_tens.pre_buffer:
- rng.mark_usage(time_to_set - 1, 2)
- else:
- rng.mark_usage(time_to_set)
+ for idx, weight_tens in enumerate(op_info.buffered_weight_tensors):
+ if weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+ rng = lr_graph.get_or_create_range(weight_tens)
+ start_time = time_to_set
+ length = 1
+ if weight_tens.pre_buffer:
+ start_time -= 1
+ length += 1
+ if len(op_info.buffered_weight_tensors) > 1:
+ last_idx = len(op_info.ofm_depth_slices) % len(op_info.buffered_weight_tensors)
+ # Double buffering: reduce end time of the buffer that is not used last
+ if last_idx != idx:
+ length -= 1
+ rng.mark_usage(start_time, length)
if time_to_set == lr_graph.current_time:
lr_graph.current_time += 2