From 9bfe0f86ea525055954b160a3c678024743030ec Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Thu, 3 Dec 2020 12:26:25 +0100 Subject: MLBEDSW-1373: Added search based allocator Added a new tensor allocator that is based on searching, implemented in C++ (C++11 compatible). Change-Id: Ie96e9fcfc8e6c58d1fa53911f37de290eeba88cf Signed-off-by: Louis Verhaard --- ethosu/vela/greedy_allocation.py | 26 --------------------- ethosu/vela/nn_graph.py | 1 + ethosu/vela/tensor_allocation.py | 50 +++++++++++++++++++++++++++++++++++++++- ethosu/vela/vela.py | 2 +- 4 files changed, 51 insertions(+), 28 deletions(-) (limited to 'ethosu/vela') diff --git a/ethosu/vela/greedy_allocation.py b/ethosu/vela/greedy_allocation.py index 58d948c2..b0395def 100644 --- a/ethosu/vela/greedy_allocation.py +++ b/ethosu/vela/greedy_allocation.py @@ -16,7 +16,6 @@ # Description: # Allocate tensor addresses using a greedy algorithm. from . import numeric_util -from .errors import AllocationError class GreedyAllocator: @@ -70,33 +69,8 @@ class GreedyAllocator: self.alloc(new_lr) - self.verify_allocation(alignment) return self.memory_required - def verify_allocation(self, alignment): - lrs = list(self.live_ranges.ranges.values()) - for n in lrs: - for tens in n.tensors: - if not all(op and op.run_on_npu for op in tens.ops + tens.consumer_list): - # This is a CPU tensor, verify alignment - if tens.address % alignment != 0: - raise AllocationError("Tensor {} not aligned to {} bytes".format(tens.name, alignment)) - - for m in lrs: - if n != m and n.overlaps_ranges(m): - overlap, tens_n, tens_m = n.overlaps_address(m) - if overlap and not (tens_n.equivalent(tens_m) and tens_n.address == tens_m.address): - raise AllocationError( - "Overlapping buffers: {}: {} -> {} and {}: {} -> {}".format( - n.name, - tens_n.address, - tens_n.address + n.size, - m.name, - tens_m.address, - tens_m.address + m.size, - ) - ) - def allocate_live_ranges(nng, arch, live_ranges, mem_area, alignment, verbose_allocation=False): g = GreedyAllocator(nng, arch, live_ranges, mem_area) diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index b2877851..0ae3de9a 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -36,6 +36,7 @@ class PassPlacement(enum.Enum): class TensorAllocator(enum.Enum): LinearAlloc = 1 Greedy = 2 + Search = 3 def __str__(self): return self.name diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py index d1a33728..9736ca22 100644 --- a/ethosu/vela/tensor_allocation.py +++ b/ethosu/vela/tensor_allocation.py @@ -24,11 +24,13 @@ from . import live_range from . import numeric_util from .errors import AllocationError from .greedy_allocation import allocate_live_ranges as greedy_allocate_live_ranges +from .live_range import LiveRangeGraph from .nn_graph import TensorAllocator from .tensor import MemArea from .tensor import MemType from .tensor import Tensor from .tensor import TensorPurpose +from ethosu import tensor_allocator def linear_allocate_live_ranges(live_ranges, alloc_granularity=Tensor.AllocationQuantum): @@ -61,7 +63,29 @@ def linear_allocate_live_ranges(live_ranges, alloc_granularity=Tensor.Allocation return total_sz -def verify_alignment(live_ranges, alignment): +def search_allocate_live_ranges(live_ranges: LiveRangeGraph, alloc_granularity: int) -> int: + # Allocates using the search-based allocator (implemented in C++) + input = [] + lrs = [] + lr_set = set() + for lr in live_ranges.ranges.values(): + lr_set.add((lr.start_time, lr.end_time, lr)) + lr_list = sorted(lr_set) + # Create a single array of ints containing start/end/size of the live ranges + for start, end, lr in lr_list: + input += [start, end, numeric_util.round_up(lr.size, alloc_granularity)] + lrs.append(lr) + addresses = tensor_allocator.allocate(input, 0) + # The result is a list containing the allocated addresses + total_sz = 0 + for lr, address in zip(lrs, addresses): + total_sz = max(total_sz, address + lr.size) + lr.set_address(address) + verify_allocation(live_ranges, alloc_granularity) + return total_sz + + +def verify_alignment(live_ranges: LiveRangeGraph, alignment: int): for lr in live_ranges.ranges.values(): for tens in lr.tensors: if not all(op and op.run_on_npu for op in tens.ops + tens.consumer_list): @@ -70,6 +94,27 @@ def verify_alignment(live_ranges, alignment): raise AllocationError("Tensor {} not aligned to {} bytes".format(tens.name, alignment)) +def verify_allocation(live_ranges: LiveRangeGraph, alignment: int): + lrs = list(live_ranges.ranges.values()) + for n in lrs: + verify_alignment(live_ranges, alignment) + + for m in lrs: + if n != m and n.overlaps_ranges(m): + overlap, tens_n, tens_m = n.overlaps_address(m) + if overlap and not (tens_n.equivalent(tens_m) and tens_n.address == tens_m.address): + raise AllocationError( + "Overlapping buffers: {}: {} -> {} and {}: {} -> {}".format( + n.name, + tens_n.address, + tens_n.address + n.size, + m.name, + tens_m.address, + tens_m.address + m.size, + ) + ) + + def mark_sram_used_for_cascaded_passes(sg, lrs): end_pos = max(ps.time for ps in sg.cascaded_passes) + 2 mem_usage = np.zeros(end_pos, dtype=np.int64) @@ -137,8 +182,11 @@ def allocate_tensors( tens_alloc = tensor_allocator if tens_alloc == TensorAllocator.Greedy: total_sz = greedy_allocate_live_ranges(sg, arch, lrs, mem_area, cpu_tensor_alignment, verbose_allocation) + verify_allocation(lrs, cpu_tensor_alignment) elif tens_alloc == TensorAllocator.LinearAlloc: total_sz = linear_allocate_live_ranges(lrs, cpu_tensor_alignment) + elif tens_alloc == TensorAllocator.Search: + total_sz = search_allocate_live_ranges(lrs, cpu_tensor_alignment) else: assert 0 alloc_ok = max_size is None or total_sz <= max_size diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py index 37de1ed2..d27eef0e 100644 --- a/ethosu/vela/vela.py +++ b/ethosu/vela/vela.py @@ -282,7 +282,7 @@ def main(args=None): ) parser.add_argument( "--tensor-allocator", - default=TensorAllocator.Greedy, + default=TensorAllocator.Search, type=lambda s: TensorAllocator[s], choices=list(TensorAllocator), help="Tensor Allocator algorithm (default: %(default)s)", -- cgit v1.2.1