about summary refs log tree commit diff
path: root/ethosu/vela/scheduler.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r-- ethosu/vela/scheduler.py 49
1 files changed, 26 insertions, 23 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 526cc0e9..4af83a10 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -1030,7 +1030,7 @@ class DynamicProgrammingScheduler:
for cp in sg.cascaded_passes:
if cp.strategy == SchedulingStrategy.IfmStream:
for ps in cp.passes:
- if ps.scale_tensor and (cp.sram_used + ps.scale_tensor.storage_size()) <= self.sram_limit:
+ if ps.scale_tensor:
tens = ps.scale_tensor
# Find op using scale tensor
@@ -1041,28 +1041,31 @@ class DynamicProgrammingScheduler:
new_tens = tens.clone_into_fast_storage(arch)
new_tens.consumer_list = tens.consumer_list.copy()
new_tens.purpose = TensorPurpose.FSBias
-
- # Create DMA cmd
- dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
- dma_cmd.inputs = [tens]
- dma_cmd.set_output_tensor(new_tens)
- dma_cmd.attrs["source"] = tens.mem_area
- dma_cmd.attrs["destination"] = new_tens.mem_area
- dma_cmd.run_on_npu = True
-
- tens.consumer_list.clear()
- tens.consumer_list.append(dma_cmd)
-
- # Replace tensor and op
- idx = op.inputs.index(tens)
- op.inputs[idx] = new_tens
-
- ps.ops.insert(0, dma_cmd)
- ps.scale_tensor = new_tens
- ps.intermediates.append(new_tens)
- ps.cascade.intermediates.append(new_tens)
-
- cp.sram_used += tens.storage_size()
+ new_tens.element_size_bytes = 10
+ new_tens_size = new_tens.storage_size()
+
+ if (cp.sram_used + new_tens_size) <= self.sram_limit:
+ # Create DMA cmd
+ dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
+ dma_cmd.inputs = [tens]
+ dma_cmd.set_output_tensor(new_tens)
+ dma_cmd.attrs["source"] = tens.mem_area
+ dma_cmd.attrs["destination"] = new_tens.mem_area
+ dma_cmd.run_on_npu = True
+
+ tens.consumer_list.clear()
+ tens.consumer_list.append(dma_cmd)
+
+ # Replace tensor and op
+ idx = op.inputs.index(tens)
+ op.inputs[idx] = new_tens
+
+ ps.ops.insert(0, dma_cmd)
+ ps.scale_tensor = new_tens
+ ps.intermediates.append(new_tens)
+ ps.cascade.intermediates.append(new_tens)
+
+ cp.sram_used += new_tens_size
def schedule_passes(nng, arch, options: SchedulerOptions):