aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/cascade_builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r--ethosu/vela/cascade_builder.py12
1 files changed, 4 insertions, 8 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 4c3f75b7..0d25ec64 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -144,10 +144,8 @@ class CascadeBuilder:
# Keep track of which Ops are in the proposed cascade as well as the best cascade so far
ops_in_cascade = [op]
ops_in_best_cascade = [op]
- # Get the size of the weight buffer
- weight_buffer = 0
- if ref_cost[op].buffered_weight_tensor:
- weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size()
+ # Get the size of the weight buffer(s)
+ weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)
# The first IFM needs to be stored in full
cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
@@ -190,10 +188,8 @@ class CascadeBuilder:
op_full_ofm = current_op.ofm_size_in_bytes()
_, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)
- # Get the size of the weight buffer
- op_weight_buffer = 0
- if ref_cost[current_op].buffered_weight_tensor:
- op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size()
+ # Get the size of the weight buffer(s)
+ op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)
# Calculate the uncascaded memory requirement for current Op
uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)