aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/cascade_builder.py
diff options
context:
space:
mode:
author Rickard Bolin <rickard.bolin@arm.com> 2022-05-16 09:11:06 +0000
committer Rickard Bolin <rickard.bolin@arm.com> 2022-05-16 15:20:20 +0000
commit fd8b500085d1ac1cca54a71631d21713a3c21f09 (patch)
tree 4a8d1c7809dc1eb748f0f0b9ba2736e5d7bb5e69 /ethosu/vela/cascade_builder.py
parent 6f4cb0362a2f00b3045565de2c27f72997b2998b (diff)
download ethos-u-vela-fd8b500085d1ac1cca54a71631d21713a3c21f09.tar.gz
MLBEDSW-6263: Use separate tensors for double buffering
Uses separate tensors for the individual weight buffers in case of weight double buffering. Each weight buffer tensor gets its own individual live range.

This patch is a clone of a previously reverted patch, but with some additional bug fixes applied.

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I868c70d15821eb9f1399186f2da6e7345f6ee343
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r-- ethosu/vela/cascade_builder.py 12
1 file changed, 4 insertions, 8 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 4703583b..e7105e2c 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -146,10 +146,8 @@ class CascadeBuilder:
# Keep track of which Ops are in the proposed cascade as well as the best cascade so far
ops_in_cascade = [op]
ops_in_best_cascade = [op]
- # Get the size of the weight buffer
- weight_buffer = 0
- if ref_cost[op].buffered_weight_tensor:
- weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size()
+ # Get the size of the weight buffer(s)
+ weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)
# The first IFM needs to be stored in full
cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
@@ -192,10 +190,8 @@ class CascadeBuilder:
op_full_ofm = current_op.ofm_size_in_bytes()
_, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)
- # Get the size of the weight buffer
- op_weight_buffer = 0
- if ref_cost[current_op].buffered_weight_tensor:
- op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size()
+ # Get the size of the weight buffer(s)
+ op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)
# Calculate the uncascaded memory requirement for current Op
uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)