about · summary · refs · log · tree · commit · diff
diff options: context / space / mode (cgit diff-view controls; widget values lost in capture)
-rw-r--r--ethosu/vela/cascade_builder.py24
-rw-r--r--ethosu/vela/scheduler.py11
2 files changed, 21 insertions, 14 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 0e651b90..95872cfe 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -188,9 +188,6 @@ class CascadeBuilder:
# The first IFM needs to be stored in full
cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
- # Add non-local memory usage
- cascade_ifm_size += self.non_local_mem_usage.get(op, 0)
-
# Sum of all intermediate cascade buffers (including weight buffers)
cascade_buffers = weight_buffer
# Best cascade size - Initially it's the fallback cost of the first Op in the cascade
@@ -248,8 +245,10 @@ class CascadeBuilder:
best_cascade_size = cascade_buffers
else:
- # Calculate the total size of the current cascade
- cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+ # Calculate the total size of the current cascade including non local mem usage
+ cascade_size = (
+ cascade_ifm_size + cascade_buffers + op_full_ofm + self.non_local_mem_usage.get(op, 0)
+ )
# Determine if cascading search should stop
if (
@@ -257,7 +256,8 @@ class CascadeBuilder:
and best_cascade_size < peak_sram_usage
or (cascade_ifm_size + cascade_buffers) > best_cascade_size
):
- # Both the existing cascade and current Op fits
+ # Both the existing cascade and current Op fits or
+ # not possible to reduce cascade size any further
break
"""
@@ -306,7 +306,7 @@ class CascadeBuilder:
hence, better to choose Cascade OP1-OP3 in this case.
"""
if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
- best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
+ best_cascade_size = cascade_size
ops_in_best_cascade = [op for op in ops_in_cascade]
producer = current_op
@@ -326,9 +326,15 @@ class CascadeBuilder:
prev_op = cascaded_op
- # Create a CascadeInfo for the cascade
+ # Create a CascadeInfo for the cascade, only store the actual size used by
+ # the cascade so non local usage is removed. This is done in order to be
+ # able to calculate the correct non local usage in the scheduler when
+ # optimizing the sub schedules.
cascade_map[cascade_end] = CascadeInfo(
- cascade_start, cascade_end, buffers_in_cascade, best_cascade_size
+ cascade_start,
+ cascade_end,
+ buffers_in_cascade,
+ best_cascade_size - self.non_local_mem_usage.get(op, 0),
)
if not self.spilling:
# Update peak memory usage
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 16531c2c..83e19bc6 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -952,8 +952,7 @@ class Scheduler:
if cost[sched_op].cascade:
# This Op is part of a cascade - use the cascade's memory usage
cascade_info = cascades[cost[sched_op].cascade]
- # Non-local memory usage is already included in the cascade_info
- peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
+ op_mem_usage = cascade_info.mem_usage + non_local_mem_usage.get(sched_op, 0)
else:
# This Op is not part of a cascade - calculate the memory usage
op_weight_buffer = sum(tens.storage_size() for tens in cost[sched_op].buffered_weight_tensors)
@@ -964,7 +963,7 @@ class Scheduler:
+ op_weight_buffer
+ non_local_mem_usage.get(sched_op, 0)
)
- peak_mem_usage = max(op_mem_usage, peak_mem_usage)
+ peak_mem_usage = max(op_mem_usage, peak_mem_usage)
return peak_mem_usage
@@ -1021,9 +1020,11 @@ class Scheduler:
time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index
mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage
# If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's
- # included in a cascade or not
+ # included in a cascade or not. Not valid in Dedicated SRAM mode (spilling enabled).
persistent_initial_ifm = (
- sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0
+ sub_schedule_ops[0].ifm_size_in_bytes()
+ if not self.arch.is_spilling_enabled() and len(sub_schedule_ops[0].ifm.connection.consumers) > 1
+ else 0
)
# Calculate non-local-mem-usage per Operator
non_local_mem_usage = {}