diff options
Diffstat (limited to 'ethosu/vela/compiler_driver.py')
-rw-r--r-- | ethosu/vela/compiler_driver.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py index 05bf65a4..1d7521b1 100644 --- a/ethosu/vela/compiler_driver.py +++ b/ethosu/vela/compiler_driver.py @@ -291,6 +291,12 @@ def compiler_driver(nng, arch, options, scheduler_options): npu_serialisation.rewrite_npu_call_ops(nng, root_sg, arch) + # Set Scratch and Fast_scratch Tensor size + if scratch_tens is not None: + scratch_tens.set_all_shapes([root_sg.memory_used_per_type.get(MemType.Scratch, 0)]) + if scratch_fast_tens is not None: + scratch_fast_tens.set_all_shapes([root_sg.memory_used_per_type.get(MemType.Scratch_fast, 0)]) + # Allocate all Cpu constant tensors, this is done last because the Npu-ops # have to be serialized into flash and scratch tensors first tensor_allocation.allocate_tensors( |