aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/cascade_builder.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2022-03-01 11:26:58 +0100
committerFredrik Svedberg <fredrik.svedberg@arm.com>2022-03-30 13:00:15 +0000
commitcc5f4de1c35ba44fca7ff6295c6ae846f8242344 (patch)
tree68c4f8124a3ee6ec6f7fceb32a1d8aec11ac9a86 /ethosu/vela/cascade_builder.py
parenta19b4671dd0594181a2789930cc98bf5dc41ded4 (diff)
downloadethos-u-vela-cc5f4de1c35ba44fca7ff6295c6ae846f8242344.tar.gz
MLBEDSW-6263: Use separate tensors for double buffering
Uses separate tensors for the individual weight buffers in case of weight double buffering. Each weight buffer tensor gets its own individual live range. Change-Id: I724a8c61a7045615fbd2ed9535663076ac8edd13 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r--ethosu/vela/cascade_builder.py12
1 files changed, 4 insertions, 8 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 4c3f75b7..0d25ec64 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -144,10 +144,8 @@ class CascadeBuilder:
# Keep track of which Ops are in the proposed cascade as well as the best cascade so far
ops_in_cascade = [op]
ops_in_best_cascade = [op]
- # Get the size of the weight buffer
- weight_buffer = 0
- if ref_cost[op].buffered_weight_tensor:
- weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size()
+ # Get the size of the weight buffer(s)
+ weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)
# The first IFM needs to be stored in full
cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
@@ -190,10 +188,8 @@ class CascadeBuilder:
op_full_ofm = current_op.ofm_size_in_bytes()
_, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)
- # Get the size of the weight buffer
- op_weight_buffer = 0
- if ref_cost[current_op].buffered_weight_tensor:
- op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size()
+ # Get the size of the weight buffer(s)
+ op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)
# Calculate the uncascaded memory requirement for current Op
uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)