diff options
-rw-r--r-- | ethosu/vela/cascade_builder.py | 24 |
-rw-r--r-- | ethosu/vela/scheduler.py | 11 |
2 files changed, 21 insertions, 14 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py index 0e651b90..95872cfe 100644 --- a/ethosu/vela/cascade_builder.py +++ b/ethosu/vela/cascade_builder.py @@ -188,9 +188,6 @@ class CascadeBuilder: # The first IFM needs to be stored in full cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0 - # Add non-local memory usage - cascade_ifm_size += self.non_local_mem_usage.get(op, 0) - # Sum of all intermediate cascade buffers (including weight buffers) cascade_buffers = weight_buffer # Best cascade size - Initially it's the fallback cost of the first Op in the cascade @@ -248,8 +245,10 @@ class CascadeBuilder: best_cascade_size = cascade_buffers else: - # Calculate the total size of the current cascade - cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm + # Calculate the total size of the current cascade including non local mem usage + cascade_size = ( + cascade_ifm_size + cascade_buffers + op_full_ofm + self.non_local_mem_usage.get(op, 0) + ) # Determine if cascading search should stop if ( @@ -257,7 +256,8 @@ class CascadeBuilder: and best_cascade_size < peak_sram_usage or (cascade_ifm_size + cascade_buffers) > best_cascade_size ): - # Both the existing cascade and current Op fits + # Both the existing cascade and current Op fits or + # not possible to reduce cascade size any further break """ @@ -306,7 +306,7 @@ class CascadeBuilder: hence, better to choose Cascade OP1-OP3 in this case. """ if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage: - best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm + best_cascade_size = cascade_size ops_in_best_cascade = [op for op in ops_in_cascade] producer = current_op @@ -326,9 +326,15 @@ class CascadeBuilder: prev_op = cascaded_op - # Create a CascadeInfo for the cascade + # Create a CascadeInfo for the cascade, only store the actual size used by + # the cascade so non local usage is removed.
This is done in order to be + able to calculate the correct non local usage in the scheduler when + optimizing the sub schedules. cascade_map[cascade_end] = CascadeInfo( - cascade_start, cascade_end, buffers_in_cascade, best_cascade_size + cascade_start, + cascade_end, + buffers_in_cascade, + best_cascade_size - self.non_local_mem_usage.get(op, 0), ) if not self.spilling: # Update peak memory usage diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 16531c2c..83e19bc6 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -952,8 +952,7 @@ class Scheduler: if cost[sched_op].cascade: # This Op is part of a cascade - use the cascade's memory usage cascade_info = cascades[cost[sched_op].cascade] - # Non-local memory usage is already included in the cascade_info - peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage) + op_mem_usage = cascade_info.mem_usage + non_local_mem_usage.get(sched_op, 0) else: # This Op is not part of a cascade - calculate the memory usage op_weight_buffer = sum(tens.storage_size() for tens in cost[sched_op].buffered_weight_tensors) @@ -964,7 +963,7 @@ class Scheduler: + op_weight_buffer + non_local_mem_usage.get(sched_op, 0) ) - peak_mem_usage = max(op_mem_usage, peak_mem_usage) + peak_mem_usage = max(op_mem_usage, peak_mem_usage) return peak_mem_usage @@ -1021,9 +1020,11 @@ class Scheduler: time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's - # included in a cascade or not + # included in a cascade or not. Not valid in Dedicated SRAM mode (spilling enabled).
persistent_initial_ifm = ( - sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0 + sub_schedule_ops[0].ifm_size_in_bytes() + if not self.arch.is_spilling_enabled() and len(sub_schedule_ops[0].ifm.connection.consumers) > 1 + else 0 ) # Calculate non-local-mem-usage per Operator non_local_mem_usage = {}