diff options
-rw-r--r-- | ethosu/vela/scheduler.py | 20 | ||||
-rw-r--r-- | ethosu/vela/tensor.py | 2 |
2 files changed, 10 insertions, 12 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index de10bad7..f96b7732 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -824,7 +824,7 @@ class Scheduler: ] # Propose different striping - the possible stripes are proposed similarly to a binary search - best_schedule = buffered_sub_schedule + best_schedule = None iteration = 0 while len(possible_stripes) > 1: proposed_stripe = possible_stripes[len(possible_stripes) // 2] @@ -860,18 +860,16 @@ class Scheduler: # Maximum performance schedule fits within the SRAM target return max_sched - # Extract the cascades - cascades = schedule.cascades - # Remove existing cascade from schedule - schedule.cascades = {} - for cost in schedule.cost_map.values(): - cost.cascade = 0 - for cascade_info in cascades.values(): + # Iterate over a copy of the cascades since they may change during the loop + for cascade_info in list(schedule.cascades.values()): # Optimize the sub-schedule in this cascade opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit) - # Update the sub-schedule Op and cascade costs to the full schedule - schedule.cost_map.update(opt_sub_schedule.cost_map) - schedule.cascades.update(opt_sub_schedule.cascades) + if opt_sub_schedule: + # Remove the existing cascade + del schedule.cascades[cascade_info.end] + # Update the sub-schedule Op and cascade costs to the full schedule + schedule.cost_map.update(opt_sub_schedule.cost_map) + schedule.cascades.update(opt_sub_schedule.cascades) # Update memory snapshot self.sg.schedule = schedule diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 8304a657..37fd06ea 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -535,7 +535,7 @@ class Tensor: assert param_a is not None shp[-1] = min(shp[-1], param_a * 2) else: - shp = list(self.storage_shape) + shp = full_shape(4, self.storage_shape, 1) if sub_purpose == TensorSubPurpose.RollingBufferX: assert len(shp) == 4 assert param_a is not None |