diff options
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/greedy_allocation.py | 4 | ||||
-rw-r--r-- | ethosu/vela/live_range.py | 18 |
2 files changed, 14 insertions, 8 deletions
diff --git a/ethosu/vela/greedy_allocation.py b/ethosu/vela/greedy_allocation.py index d6896a59..8393434d 100644 --- a/ethosu/vela/greedy_allocation.py +++ b/ethosu/vela/greedy_allocation.py @@ -46,8 +46,8 @@ class GreedyAllocator: current_offset = start_addr + lr.size + best_offset = new_lr.set_address(best_offset) self.memory_required = max(self.memory_required, best_offset + size) - new_lr.set_address(best_offset) self.current_allocs.append((best_offset, new_lr)) self.current_allocs = list(sorted(self.current_allocs)) @@ -77,7 +77,7 @@ class GreedyAllocator: for m in lrs: if n != m and n.overlaps_ranges(m): overlap, tens_n, tens_m = n.overlaps_address(m) - if overlap: + if overlap and not (tens_n.equivalence_id == tens_m.equivalence_id and tens_n.address == tens_m.address): print("Solution failed, overlapping buffer!") print(tens_n.address, tens_n.address + n.size, n.name) print(tens_m.address, tens_m.address + m.size, m.name) diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index 23ab67d9..2a35a119 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -86,12 +86,18 @@ class LiveRange: # Set address of all unaddressed tensors in LiveRange for tens in self.tensors: if tens.address == 0: - tens.address = address - # Also need to set the address to the tensor's cpu/npu clones - if tens.cpu_tensor is not None: - tens.cpu_tensor.address = address - if tens.npu_tensor is not None: - tens.npu_tensor.address = address + addr = address + else: + # Limit to single tensor for the lr if the tensor address already assigned + assert len(self.tensors) == 1 + addr = tens.address + tens.address = addr + # Also need to set the address to the tensor's cpu/npu clones + if tens.cpu_tensor is not None: + tens.cpu_tensor.address = addr + if tens.npu_tensor is not None: + tens.npu_tensor.address = addr + return addr def get_alignment(self): # Get max alignment of LiveRange's tensors |