diff options
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r-- | ethosu/vela/cascade_builder.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py index 0e651b90..95872cfe 100644 --- a/ethosu/vela/cascade_builder.py +++ b/ethosu/vela/cascade_builder.py @@ -188,9 +188,6 @@ class CascadeBuilder: # The first IFM needs to be stored in full cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0 - # Add non-local memory usage - cascade_ifm_size += self.non_local_mem_usage.get(op, 0) - # Sum of all intermediate cascade buffers (including weight buffers) cascade_buffers = weight_buffer # Best cascade size - Initially it's the fallback cost of the first Op in the cascade @@ -248,8 +245,10 @@ class CascadeBuilder: best_cascade_size = cascade_buffers else: - # Calculate the total size of the current cascade - cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm + # Calculate the total size of the current cascade including non local mem usage + cascade_size = ( + cascade_ifm_size + cascade_buffers + op_full_ofm + self.non_local_mem_usage.get(op, 0) + ) # Determine if cascading search should stop if ( @@ -257,7 +256,8 @@ class CascadeBuilder: and best_cascade_size < peak_sram_usage or (cascade_ifm_size + cascade_buffers) > best_cascade_size ): - # Both the existing cascade and current Op fits + # Both the existing cascade and current Op fits or + # not possible to reduce cascade size any further break """ @@ -306,7 +306,7 @@ class CascadeBuilder: hence, better to choose Cascade OP1-OP3 in this case. """ if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage: - best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm + best_cascade_size = cascade_size ops_in_best_cascade = [op for op in ops_in_cascade] producer = current_op @@ -326,9 +326,15 @@ class CascadeBuilder: prev_op = cascaded_op - # Create a CascadeInfo for the cascade + # Create a CascadeInfo for the cascade, only store the actual size used by + # the cascade so non local usage is removed. This is done in order to be + # able to calculate the correct non local usage in the scheduler when + # optimizing the sub schedules. cascade_map[cascade_end] = CascadeInfo( - cascade_start, cascade_end, buffers_in_cascade, best_cascade_size + cascade_start, + cascade_end, + buffers_in_cascade, + best_cascade_size - self.non_local_mem_usage.get(op, 0), ) if not self.spilling: # Update peak memory usage |