aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/range_set.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/range_set.py')
-rw-r--r--ethosu/vela/range_set.py154
1 files changed, 154 insertions, 0 deletions
diff --git a/ethosu/vela/range_set.py b/ethosu/vela/range_set.py
new file mode 100644
index 00000000..64de9709
--- /dev/null
+++ b/ethosu/vela/range_set.py
@@ -0,0 +1,154 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Description:
+# Helper classes to track memory accesses for calculating dependencies between Commands.
+
+from enum import IntEnum
+from collections import defaultdict
+from functools import lru_cache
+
+
+class RangeSet:
+ """A Range set class to track ranges and whether they intersect.
+Intended for e.g. tracking sets of memory ranges and whether two commands use the same memory areas."""
+
+ def __init__(self, start=None, end=None, ranges=None):
+ if ranges is None:
+ ranges = []
+
+ self.ranges = ranges # track a list of (start, end) tuples, always in ascending order sorted by start.
+
+ if start is not None and start != end:
+ assert start < end
+ self.ranges.append((start, end))
+
+ def __or__(self, other):
+ combined_ranges = list(sorted(self.ranges + other.ranges))
+ return RangeSet(ranges=combined_ranges)
+
+ def __ior__(self, other):
+ self.ranges = list(sorted(self.ranges + other.ranges))
+ return self
+
+ def intersects(self, other):
+ a_ranges = self.ranges
+ b_ranges = other.ranges
+
+ a_idx = 0
+ b_idx = 0
+
+ while a_idx < len(a_ranges) and b_idx < len(b_ranges):
+ ar = a_ranges[a_idx]
+ br = b_ranges[b_idx]
+ if max(ar[0], br[0]) < min(ar[1], br[1]):
+ return True # intersection
+
+ # advance one of the two upwards
+ if ar[0] < br[0]:
+ a_idx += 1
+ else:
+ assert ar[0] != br[0]
+ # note ar[0] == br[0] cannot happen, then we'd have an intersection
+ b_idx += 1
+
+ return False
+
+ def __str__(self):
+ return "<RangeSet %s>" % (["%#x:%#x" % (int(start), int(end)) for start, end in self.ranges],)
+
+ __repr__ = __str__
+
+
+class MemoryRangeSet:
+ """Extended version of the RangeSet class that handles having different memory areas"""
+
+ def __init__(self, mem_area=None, start=None, end=None, regions=None):
+
+ if regions is None:
+ regions = {}
+ self.regions = regions
+
+ if mem_area is not None:
+ self.regions[mem_area] = RangeSet(start, end)
+
+ def __or__(self, other):
+ combined_regions = {
+ mem_area: (self.regions.get(mem_area, RangeSet()) | other.regions.get(mem_area, RangeSet()))
+ for mem_area in (self.regions.keys() | other.regions.keys())
+ }
+ return MemoryRangeSet(regions=combined_regions)
+
+ def __ior__(self, other):
+ self.regions = {
+ mem_area: (self.regions.get(mem_area, RangeSet()) | other.regions.get(mem_area, RangeSet()))
+ for mem_area in (self.regions.keys() | other.regions.keys())
+ }
+ return self
+
+ def intersects(self, other):
+ for mem_area in self.regions.keys() & other.regions.keys():
+ if self.regions[mem_area].intersects(other.regions[mem_area]):
+ return True
+ return False
+
+ def __str__(self):
+ s = "<MemoryRangeSet>"
+ for mem_area, rng in self.regions.items():
+ s += "%s: %s\t" % (mem_area, rng)
+ return s
+
+ __repr__ = __str__
+
+
+class AccessDirection(IntEnum):
+ Read = 0
+ Write = 1
+ Size = 2
+
+
+class MemoryAccessSet:
+ """Tracks memory ranges, but also access patterns to know which accesses actually are in conflict"""
+
+ def __init__(self):
+ self.accesses = [MemoryRangeSet() for i in range(AccessDirection.Size)]
+
+ def add(self, memory_range_set, access):
+ self.accesses[access] |= memory_range_set
+
+ @lru_cache(maxsize=None)
+ def conflicts(self, other):
+
+ # True dependencies, or write -> read
+ if self.accesses[AccessDirection.Write].intersects(other.accesses[AccessDirection.Read]):
+ return True
+
+ # Anti-dependencies, or read -> write
+ if self.accesses[AccessDirection.Read].intersects(other.accesses[AccessDirection.Write]):
+ return True
+
+ # Output dependencies, or write -> write
+ if self.accesses[AccessDirection.Write].intersects(other.accesses[AccessDirection.Write]):
+ return True
+
+ # read -> read does not cause a conflict
+ return False
+
+ def __str__(self):
+ return "Read: %s\nWrite: %s\n\n" % (self.accesses[AccessDirection.Read], self.accesses[AccessDirection.Write])
+
+ __repr__ = __str__