1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
# SPDX-FileCopyrightText: Copyright 2020, 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Helper classes to track memory accesses for calculating dependencies between Commands.
from enum import IntEnum
from functools import lru_cache
class RangeSet:
"""A Range set class to track ranges and whether they intersect.
Intended for e.g. tracking sets of memory ranges and whether two commands use the same memory areas."""
def __init__(self, start=None, end=None, ranges=None):
if ranges is None:
ranges = []
self.ranges = ranges # track a list of (start, end) tuples, always in ascending order sorted by start.
if start is not None and start != end:
self.ranges.append((start, end))
def __or__(self, other):
combined_ranges = list(sorted(self.ranges + other.ranges))
return RangeSet(ranges=combined_ranges)
def __ior__(self, other):
self.ranges = list(sorted(self.ranges + other.ranges))
return self
def intersects(self, other):
a_ranges = self.ranges
b_ranges = other.ranges
a_idx = 0
b_idx = 0
while a_idx < len(a_ranges) and b_idx < len(b_ranges):
ar = a_ranges[a_idx]
br = b_ranges[b_idx]
if max(ar[0], br[0]) < min(ar[1], br[1]):
return True # intersection
# advance one of the two upwards
if ar[0] < br[0]:
a_idx += 1
else:
assert ar[0] != br[0]
# note ar[0] == br[0] cannot happen, then we'd have an intersection
b_idx += 1
return False
def __str__(self):
return "<RangeSet %s>" % (["%#x:%#x" % (int(start), int(end)) for start, end in self.ranges],)
__repr__ = __str__
class MemoryRangeSet:
"""Extended version of the RangeSet class that handles having different memory areas"""
def __init__(self, mem_area=None, start=None, end=None, regions=None):
if regions is None:
regions = {}
self.regions = regions
if mem_area is not None:
self.regions[mem_area] = RangeSet(start, end)
def __or__(self, other):
combined_regions = {
mem_area: (self.regions.get(mem_area, RangeSet()) | other.regions.get(mem_area, RangeSet()))
for mem_area in (self.regions.keys() | other.regions.keys())
}
return MemoryRangeSet(regions=combined_regions)
def __ior__(self, other):
self.regions = {
mem_area: (self.regions.get(mem_area, RangeSet()) | other.regions.get(mem_area, RangeSet()))
for mem_area in (self.regions.keys() | other.regions.keys())
}
return self
def intersects(self, other):
for mem_area in self.regions.keys() & other.regions.keys():
if self.regions[mem_area].intersects(other.regions[mem_area]):
return True
return False
def __str__(self):
s = "<MemoryRangeSet>"
for mem_area, rng in self.regions.items():
s += "%s: %s\t" % (mem_area, rng)
return s
__repr__ = __str__
class AccessDirection(IntEnum):
Read = 0
Write = 1
Size = 2
class MemoryAccessSet:
"""Tracks memory ranges, but also access patterns to know which accesses actually are in conflict"""
def __init__(self):
self.accesses = [MemoryRangeSet() for i in range(AccessDirection.Size)]
def add(self, memory_range_set, access):
self.accesses[access] |= memory_range_set
@lru_cache(maxsize=None)
def conflicts(self, other):
# True dependencies, or write -> read
if self.accesses[AccessDirection.Write].intersects(other.accesses[AccessDirection.Read]):
return True
# Anti-dependencies, or read -> write
if self.accesses[AccessDirection.Read].intersects(other.accesses[AccessDirection.Write]):
return True
# Output dependencies, or write -> write
if self.accesses[AccessDirection.Write].intersects(other.accesses[AccessDirection.Write]):
return True
# read -> read does not cause a conflict
return False
def __str__(self):
return "Read: %s\nWrite: %s\n\n" % (self.accesses[AccessDirection.Read], self.accesses[AccessDirection.Write])
__repr__ = __str__
|