aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor_allocation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor_allocation.py')
-rw-r--r--ethosu/vela/tensor_allocation.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index d53babc3..1efcd686 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -128,7 +128,10 @@ def allocate_tensors(
show_minimum_possible_allocation=False,
lr_graph=None,
allocation_alignment=Tensor.AllocationQuantum,
+ max_size=None,
+ dry_test=False,
):
+ # Allocates addresses to tensors, returns False if tensors could not be fit within max_size
ignore_subgraph_input_output_tensors = False
lrs = live_range.extract_live_ranges_from_cascaded_passes(
sg,
@@ -149,6 +152,12 @@ def allocate_tensors(
total_sz = linear_allocate_live_ranges(lrs, allocation_alignment)
else:
assert 0
+ alloc_ok = max_size is None or total_sz <= max_size
+ if dry_test or not alloc_ok:
+ # Dry test or allocation failed; undo allocation
+ for lr in lrs.ranges.values():
+ lr.set_address(None)
+ return alloc_ok
if sg.memory_used.get(mem_area, 0) == 0:
sg.memory_used[mem_area] = total_sz
@@ -179,3 +188,4 @@ def allocate_tensors(
nng.bits_per_element[mem_area] = nng.total_size[mem_area] * 8 / nng.total_elements[mem_area]
except ZeroDivisionError:
nng.bits_per_element[mem_area] = 0.0
+ return True