aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/cascade_builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r--ethosu/vela/cascade_builder.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 0e651b90..95872cfe 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -188,9 +188,6 @@ class CascadeBuilder:
# The first IFM needs to be stored in full
cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
- # Add non-local memory usage
- cascade_ifm_size += self.non_local_mem_usage.get(op, 0)
-
# Sum of all intermediate cascade buffers (including weight buffers)
cascade_buffers = weight_buffer
# Best cascade size - Initially it's the fallback cost of the first Op in the cascade
@@ -248,8 +245,10 @@ class CascadeBuilder:
best_cascade_size = cascade_buffers
else:
- # Calculate the total size of the current cascade
- cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+ # Calculate the total size of the current cascade including non local mem usage
+ cascade_size = (
+ cascade_ifm_size + cascade_buffers + op_full_ofm + self.non_local_mem_usage.get(op, 0)
+ )
# Determine if cascading search should stop
if (
@@ -257,7 +256,8 @@ class CascadeBuilder:
and best_cascade_size < peak_sram_usage
or (cascade_ifm_size + cascade_buffers) > best_cascade_size
):
- # Both the existing cascade and current Op fits
+ # Both the existing cascade and current Op fits or
+ # not possible to reduce cascade size any further
break
"""
@@ -306,7 +306,7 @@ class CascadeBuilder:
hence, better to choose Cascade OP1-OP3 in this case.
"""
if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
- best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+ best_cascade_size = cascade_size
ops_in_best_cascade = [op for op in ops_in_cascade]
producer = current_op
@@ -326,9 +326,15 @@ class CascadeBuilder:
prev_op = cascaded_op
- # Create a CascadeInfo for the cascade
+ # Create a CascadeInfo for the cascade, only store the actual size used by
+ # the cascade so non local usage is removed. This is done in order to be
+ # able to calculate the correct non local usage in the scheduler when
+ # optimizing the sub schedules.
cascade_map[cascade_end] = CascadeInfo(
- cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
+ cascade_start,
+ cascade_end,
+ buffers_in_cascade,
+ best_cascade_size - self.non_local_mem_usage.get(op, 0),
)
if not self.spilling:
# Update peak memory usage