diff options
author | Rickard Bolin <rickard.bolin@arm.com> | 2022-05-16 09:11:06 +0000 |
---|---|---|
committer | Rickard Bolin <rickard.bolin@arm.com> | 2022-05-16 15:20:20 +0000 |
commit | fd8b500085d1ac1cca54a71631d21713a3c21f09 (patch) | |
tree | 4a8d1c7809dc1eb748f0f0b9ba2736e5d7bb5e69 /ethosu/vela/live_range.py | |
parent | 6f4cb0362a2f00b3045565de2c27f72997b2998b (diff) | |
download | ethos-u-vela-fd8b500085d1ac1cca54a71631d21713a3c21f09.tar.gz |
MLBEDSW-6263: Use separate tensors for double buffering
Uses separate tensors for the individual weight buffers
in case of weight double buffering.
Each weight buffer tensor gets its own individual live range.
This patch is a clone of a previously reverted patch, but with some
additional bug fixes applied.
Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I868c70d15821eb9f1399186f2da6e7345f6ee343
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r-- | ethosu/vela/live_range.py | 23 |
1 files changed, 15 insertions, 8 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index 19d0c11f..ccf49297 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -63,7 +63,7 @@ class LiveRange: def mark_usage(self, op_time, op_length=1): op_time_start = max(op_time, 0) op_time_end = op_time + op_length - if op_time_end <= op_time_start: + if op_time_end < op_time_start: return self.start_time = min(self.start_time, op_time_start) @@ -325,13 +325,20 @@ def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, rng.mark_usage(time_to_set) - weight_tens = op_info.buffered_weight_tensor - if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area: - rng = lr_graph.get_or_create_range(weight_tens) - if weight_tens.pre_buffer: - rng.mark_usage(time_to_set - 1, 2) - else: - rng.mark_usage(time_to_set) + for idx, weight_tens in enumerate(op_info.buffered_weight_tensors): + if weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area: + rng = lr_graph.get_or_create_range(weight_tens) + start_time = time_to_set + length = 1 + if weight_tens.pre_buffer: + start_time -= 1 + length += 1 + if len(op_info.buffered_weight_tensors) > 1: + last_idx = len(op_info.ofm_depth_slices) % len(op_info.buffered_weight_tensors) + # Double buffering: reduce end time of the buffer that is not used last + if last_idx != idx: + length -= 1 + rng.mark_usage(start_time, length) if time_to_set == lr_graph.current_time: lr_graph.current_time += 2 |