-rw-r--r--   ethosu/vela/scheduler.py   20
-rw-r--r--   ethosu/vela/tensor.py       2
2 files changed, 10 insertions, 12 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index de10bad7..f96b7732 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -824,7 +824,7 @@ class Scheduler:
]
# Propose different striping - the possible stripes are proposed similarly to a binary search
- best_schedule = buffered_sub_schedule
+ best_schedule = None
iteration = 0
while len(possible_stripes) > 1:
proposed_stripe = possible_stripes[len(possible_stripes) // 2]
@@ -860,18 +860,16 @@ class Scheduler:
# Maximum performance schedule fits within the SRAM target
return max_sched
- # Extract the cascades
- cascades = schedule.cascades
- # Remove existing cascade from schedule
- schedule.cascades = {}
- for cost in schedule.cost_map.values():
- cost.cascade = 0
- for cascade_info in cascades.values():
+ # Iterate over a copy of the cascades since they may change during the loop
+ for cascade_info in list(schedule.cascades.values()):
# Optimize the sub-schedule in this cascade
opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
- # Update the sub-schedule Op and cascade costs to the full schedule
- schedule.cost_map.update(opt_sub_schedule.cost_map)
- schedule.cascades.update(opt_sub_schedule.cascades)
+ if opt_sub_schedule:
+ # Remove the existing cascade
+ del schedule.cascades[cascade_info.end]
+ # Update the sub-schedule Op and cascade costs to the full schedule
+ schedule.cost_map.update(opt_sub_schedule.cost_map)
+ schedule.cascades.update(opt_sub_schedule.cascades)
# Update memory snapshot
self.sg.schedule = schedule
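
For context, the scheduler change above avoids mutating schedule.cascades while iterating over it (it loops over list(schedule.cascades.values())) and only removes the old cascade when optimize_sub_schedule() actually returned a result. A minimal, self-contained sketch of that pattern follows; the optimize() helper and the Cascade class here are hypothetical stand-ins for illustration, not Vela's API:

    # Sketch: replacing dict entries while looping over a snapshot of them.
    # "Cascade" and "optimize" are illustrative stand-ins, not Vela code.
    from typing import Dict, Optional


    class Cascade:
        def __init__(self, end: int, cost: int):
            self.end = end    # key of this cascade in the containing dict
            self.cost = cost


    def optimize(cascade: Cascade) -> Optional[Dict[int, Cascade]]:
        # Pretend optimization: split a costly cascade into two cheaper ones,
        # or return None when no better alternative is found.
        if cascade.cost <= 1:
            return None
        half = cascade.cost // 2
        return {
            cascade.end: Cascade(cascade.end, half),
            cascade.end + 1: Cascade(cascade.end + 1, cascade.cost - half),
        }


    cascades: Dict[int, Cascade] = {3: Cascade(3, 4), 7: Cascade(7, 1)}

    # Iterate over a snapshot (list(...)) so the dict itself may be mutated.
    for cascade in list(cascades.values()):
        replacement = optimize(cascade)
        if replacement:
            # Only drop the existing entry when optimization succeeded.
            del cascades[cascade.end]
            cascades.update(replacement)

    print(sorted(cascades))  # -> [3, 4, 7]

Iterating over the snapshot sidesteps the RuntimeError Python raises when a dict changes size during iteration, and the guard keeps the original cascade in place whenever no improved sub-schedule is produced.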
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 8304a657..37fd06ea 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -535,7 +535,7 @@ class Tensor:
assert param_a is not None
shp[-1] = min(shp[-1], param_a * 2)
else:
- shp = list(self.storage_shape)
+ shp = full_shape(4, self.storage_shape, 1)
if sub_purpose == TensorSubPurpose.RollingBufferX:
assert len(shp) == 4
assert param_a is not None
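
For reference, the tensor.py change swaps list(self.storage_shape) for full_shape(4, self.storage_shape, 1) so that the rolling-buffer code below (which asserts len(shp) == 4) always sees a four-element shape even when storage_shape has fewer dimensions. A rough sketch of what a helper like full_shape presumably does, assuming the usual left-pad-with-fill semantics; this is an illustration, not the module's actual implementation:

    # Rough sketch of full_shape() semantics: pad the shape on the left with
    # the fill value until it has at least `dim` entries.
    def full_shape(dim, shape, fill):
        return [fill] * (dim - len(shape)) + list(shape)


    print(full_shape(4, [32, 32, 16], 1))   # -> [1, 32, 32, 16]
    print(full_shape(4, [1, 8, 8, 24], 1))  # -> [1, 8, 8, 24]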