diff options
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r-- | ethosu/vela/cascade_builder.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py index 0d25ec64..4c3f75b7 100644 --- a/ethosu/vela/cascade_builder.py +++ b/ethosu/vela/cascade_builder.py @@ -144,8 +144,10 @@ class CascadeBuilder: # Keep track of which Ops are in the proposed cascade as well as the best cascade so far ops_in_cascade = [op] ops_in_best_cascade = [op] - # Get the size of the weight buffer(s) - weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors) + # Get the size of the weight buffer + weight_buffer = 0 + if ref_cost[op].buffered_weight_tensor: + weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size() # The first IFM needs to be stored in full cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0 @@ -188,8 +190,10 @@ class CascadeBuilder: op_full_ofm = current_op.ofm_size_in_bytes() _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost) - # Get the size of the weight buffer(s) - op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors) + # Get the size of the weight buffer + op_weight_buffer = 0 + if ref_cost[current_op].buffered_weight_tensor: + op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size() # Calculate the uncascaded memory requirement for current Op uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0) |