aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/live_range.py
blob: 0c8215daf01a72579fd2a4ba0591ca1b9514ab5f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
# Can work with either a pass packed subgraph or a scheduled subgraph.
from collections import namedtuple
from typing import List

import numpy as np

from .operation import Op
from .tensor import MemArea
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorPurpose
from .utils import progress_print


class LiveRange:
    """A time interval and a byte size shared by one or more tensors.

    Every tensor added to the same LiveRange is given the same address by
    set_address().
    """

    def __init__(self, tens, alignment):
        # Tensors that are assigned to the same LiveRange will be allocated to the same address
        self.tensors = []
        # Extreme initial values so that the min()/max() in mark_usage() adopt the first marked times
        self.start_time = 99999999999
        self.end_time = -1
        self.size = 0
        self.name = ""
        self.alignment = alignment

        if tens:
            self.mem_area = tens.mem_area
            self.add_tensor(tens)
        else:
            self.mem_area = MemArea.Unknown

    def __str__(self):
        text = (
            f"<live_range.LiveRange: {self.start_time}-{self.end_time}, "
            f"size={self.size}, '{self.name}' #:{len(self.tensors)}>"
        )
        return text

    __repr__ = __str__

    def add_tensor(self, tens):
        """Attach tens to this range; while the size is still 0 the tensor also sets size and name."""
        if self.size != 0:
            assert (
                self.size >= tens.storage_size()
            ), "Tensors assigned to the same LiveRange need to fit the size of the LiveRange."
        else:
            # LiveRange will be named after the first tensor added
            self.size = tens.storage_size()
            self.name = tens.name

        self.tensors.append(tens)

    def mark_usage(self, op_time, op_length=1):
        """Widen the range to cover [max(op_time, 0), op_time + op_length].

        Nothing is changed when the usage ends before time 0.
        """
        first = max(op_time, 0)
        last = op_time + op_length
        if last >= first:
            self.start_time = min(self.start_time, first)
            self.end_time = max(self.end_time, last)

    def set_buffer_size(self, buffer_size):
        """Override the size of the range and place it in MemArea.Sram."""
        self.size = buffer_size
        self.mem_area = MemArea.Sram

    def overlaps_ranges(self, other):
        """Return True if the time intervals of this range and other intersect."""
        latest_start = max(self.start_time, other.start_time)
        earliest_end = min(self.end_time, other.end_time)
        return latest_start < earliest_end

    def overlaps_address(self, other):
        """Look for a tensor of this range whose address span intersects one of other's.

        Returns (True, tens, other_tens) for the first overlapping pair found,
        otherwise (False, None, None). Spans are measured with the size of the
        LiveRange, not with the size of the individual tensor.
        """
        for mine in self.tensors:
            for theirs in other.tensors:
                span_start = max(mine.address, theirs.address)
                span_end = min(mine.address + self.size, theirs.address + other.size)
                if span_start < span_end:
                    return True, mine, theirs

        return False, None, None

    def __lt__(self, other):
        # Ordered by start time, then end time, then size and finally name
        own_key = (self.start_time, self.end_time, self.size, self.name)
        other_key = (other.start_time, other.end_time, other.size, other.name)
        return own_key < other_key

    def set_address(self, address):
        """Give every tensor in the range the same address and return that address."""
        for member in self.tensors:
            member.address = address

        return address

    def get_alignment(self):
        return self.alignment

    def set_alignment(self, alignment):
        """Raise the alignment to the given value; a smaller value leaves it unchanged."""
        if alignment > self.alignment:
            self.alignment = alignment


class LiveRangeGraph:
    """Collection of LiveRange objects together with the time counter used when marking them."""

    def __init__(self):
        self.lrs: List[LiveRange] = []  # List of all created ranges
        self.ranges = {}  # tens -> range
        self.processed_subgraphs = set()
        self.current_time = 0
        self.end_time = None

    def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
        """Return the live range of tens (or of an equivalent tensor), creating one if needed.

        When an existing range is reused its alignment is raised to at least
        the requested alignment.
        """
        for known_tens, live_range in self.ranges.items():
            if tens.equivalent(known_tens):
                break
        else:
            # No live range found for the tensor, create a new one
            live_range = LiveRange(tens, alignment)
            self.ranges[tens] = live_range
            self.lrs.append(live_range)
            return live_range

        live_range.set_alignment(alignment)
        return live_range

    def fuse_ranges(self, in_tens, out_tens):
        """Put out_tens into the live range of in_tens and return that range.

        out_tens must not already be registered in the graph.
        """
        shared_range = self.get_or_create_range(in_tens)
        assert out_tens not in self.ranges, out_tens
        shared_range.add_tensor(out_tens)
        self.ranges[out_tens] = shared_range
        return shared_range

    def get_endtime(self):
        # op_length is 1 so max end time for lr is current + 1
        return self.current_time + 1

    def get_temporal_memory_usage(self, target_mem_area):
        """Return an int32 array with the summed size of the live ranges in target_mem_area per time step."""
        graph_end = self.get_endtime()
        usage = np.zeros(graph_end + 1, dtype=np.int32)
        in_area = [lr for lr in self.lrs if lr.mem_area == target_mem_area]
        for lr in in_area:
            # End time is inclusive
            assert lr.end_time <= graph_end + 1
            usage[lr.start_time : lr.end_time + 1] += lr.size

        return usage


def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
    """Tell whether tens is to be left out of live range extraction.

    Virtual tensors are always ignored. If either target argument is None no
    further filtering is done. Otherwise a tensor is kept only when it lives
    in target_mem_area and its mem_type is in target_mem_type_set.
    """
    if tens.purpose == TensorPurpose.Virtual:
        return True

    filtering_enabled = target_mem_area is not None and target_mem_type_set is not None
    if not filtering_enabled:
        return False

    matches_target = tens.mem_area == target_mem_area and tens.mem_type in target_mem_type_set
    return not matches_target


def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
    """Return the input tensor whose live range can be shared with the op's ofm, or None.

    Two kinds of op are considered: elementwise ops whose memory function is
    not Op.VariableTensorWrite, and Op.Memcpy. Any other op yields None.
    """
    parent_op = sched_op.parent_op

    if sched_op.op_type.is_elementwise_op() and parent_op.memory_function is not Op.VariableTensorWrite:
        # Check if possible to merge ifm/ofm live ranges of elementwise op
        if tensor_should_be_ignored(parent_op.ofm, target_mem_area, target_mem_type_set):
            return None

        # Check if overwriting the inputs can be allowed
        ofm_shape = parent_op.ofm_shapes[0]
        ofm = parent_op.ofm
        # (op shape, tensor) pairs of the inputs that are present, ifm before ifm2
        candidates = []
        if parent_op.ifm is not None:
            candidates.append((parent_op.ifm_shapes[0], parent_op.ifm))
        if parent_op.ifm2 is not None:
            candidates.append((parent_op.ifm_shapes[1], parent_op.ifm2))

        # find an input tensor that can be overwritten by the output
        for in_shape, in_tens in candidates:
            # op input and output shapes must allow overlapping
            shapes_match = in_shape == ofm_shape
            if not shapes_match:
                continue

            # input tensor must be valid
            tensor_usable = (
                in_tens is not None
                and in_tens.shape != []
                and not in_tens.ifm_write_protected
                and not tensor_should_be_ignored(in_tens, target_mem_area, target_mem_type_set)
            )
            if not tensor_usable:
                continue

            # input and output tensors must be compatible
            compatible = in_tens.format == ofm.format and in_tens.dtype == ofm.dtype
            if not compatible:
                continue

            # input tensor may only have one consumer, output tensor only one producer
            if len(in_tens.consumer_list) == 1 and len(ofm.ops) == 1:
                return in_tens

        return None

    if sched_op.op_type == Op.Memcpy:
        # Check if possible to merge ifm/ofm live ranges of dma op
        dma_op = sched_op.parent_op
        ifm = dma_op.ifm
        ofm = dma_op.ofm
        if tensor_should_be_ignored(ifm, target_mem_area, target_mem_type_set):
            return None
        if tensor_should_be_ignored(ofm, target_mem_area, target_mem_type_set):
            return None
        # input tensor only allowed to have one consumer
        if len(ifm.consumer_list) > 1:
            return None
        # Currently DMA only used when bypassing memory only ops so ok to reuse ifm
        # if ifm has only one consumer
        return ifm

    return None


def ofm_can_reuse_ifm(sched_op, target_mem_area=None, target_mem_type_set=None):
    """Return True when _get_ifm_to_fuse() finds an input tensor the ofm of sched_op can share a live range with."""
    return _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set) is not None


def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
    """Fuse the ofm of sched_op into the live range of a reusable input tensor, if there is one.

    The decision is taken by _get_ifm_to_fuse(); when it returns a tensor, the
    ofm of the parent op is added to that tensor's live range in lr_graph.
    Nothing happens when it returns None.

    sg is not used by this function; the parameter is kept so that existing
    callers keep working.
    """
    ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set)
    # Compare against None explicitly, the same criterion ofm_can_reuse_ifm() uses,
    # so that the outcome never depends on the truth value of the tensor object
    if ifm is not None:
        lr_graph.fuse_ranges(ifm, sched_op.parent_op.ofm)


def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    lr_graph=None,
    cpu_tensor_alignment=Tensor.AllocationQuantum,
    verbose_progress: bool = False,
):
    """Mark live ranges for the tensors used by the cascaded passes of a subgraph.

    Each cascaded pass is stamped with a time (stored in cps.time) and the
    tensors it reads, produces or keeps as intermediates have their live
    ranges marked at that time. If the first op of the first pass carries a
    "subgraph" attribute, the referenced subgraph(s) are processed first and
    the pass takes the time reached afterwards.

    Args:
        sg: Subgraph providing cascaded_passes and output_tensors.
        target_mem_area: Passed to tensor_should_be_ignored() to filter tensors.
        target_mem_type_set: Passed to tensor_should_be_ignored() to filter
            tensors. It is also tested with `in` when a pass has a nested
            subgraph, so it must be a container in that case.
        lr_graph: Graph to extend. A new LiveRangeGraph is created when None.
        cpu_tensor_alignment: Alignment requested for the ranges created or
            looked up by this function.
        verbose_progress: Forwarded to progress_print(). It is not forwarded
            to the nested extraction calls.

    Returns:
        The (possibly newly created) live range graph. If sg has already been
        processed the graph is returned unchanged.
    """
    if lr_graph is None:
        lr_graph = LiveRangeGraph()

    if sg in lr_graph.processed_subgraphs:
        # if subgraph has been processed already, return the lr_graph as is
        return lr_graph

    for index, cps in enumerate(sg.cascaded_passes):
        progress_print(verbose_progress, "Processing cascaded pass", index, sg.cascaded_passes)
        cps.time = lr_graph.current_time

        time_for_pass = cps.time

        # Inputs are marked at the time the pass starts, before any nested subgraph is handled
        for tens in cps.inputs:
            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                continue
            rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
            rng.mark_usage(time_for_pass)

        # Only the first op of the first pass is inspected for a nested subgraph
        op = cps.passes[0].ops[0] if cps.passes[0].ops else None
        op_subgraph = op.attrs.get("subgraph", None) if op else None

        if op_subgraph is not None and MemType.Permanent_CPU not in target_mem_type_set:
            if op.type == Op.CustomNpuOp:
                # If the primary-op is an NpuOp that means this is where an Npu subgraph
                # is called. Go into said subgraph and extract live ranges before continuing.
                # Use default allocation alignment of 16 for Npu tensors
                lr_graph = extract_live_ranges_from_schedule(
                    op_subgraph, target_mem_area, target_mem_type_set, lr_graph
                )
            else:
                # The op has one or more subgraphs in it (a typical op is the While op)
                # Go into all subgraphs and extract live ranges before continuing.
                for op_sg in op_subgraph:
                    lr_graph = extract_live_ranges_from_cascaded_passes(
                        op_sg, target_mem_area, target_mem_type_set, lr_graph, cpu_tensor_alignment
                    )
            # Set the new time after handling the Npu subgraph
            # current time is updated in subgraph path so do not tick the time
            time_for_pass = lr_graph.current_time
            cps.time = time_for_pass
        else:
            lr_graph.current_time += 2

        # Intermediates and outputs are marked at the pass time, which is the time reached
        # after the nested subgraph(s) when there were any
        for tens in cps.intermediates + cps.outputs:
            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                continue
            rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
            rng.mark_usage(time_for_pass)

    # The subgraph outputs are marked at the time reached after the last pass
    time_to_set = lr_graph.current_time
    for tens in sg.output_tensors:
        if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
            continue
        rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
        rng.mark_usage(time_to_set)

    # Variable tensor live-range is for entire inference
    # (this loops over every range registered in the graph, not only those of this subgraph)
    for tens, rng in lr_graph.ranges.items():
        if tens.is_variable:
            rng.mark_usage(0, time_to_set + 1)

    # Add subgraph to set of processed subgraphs
    lr_graph.processed_subgraphs.add(sg)
    return lr_graph


def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
    """Mark all selected tensors of sg at one single time step and advance the time by one.

    Pass tensors with purpose Weights are skipped; the NPU weight and scale
    tensors stored in the schedule's cost map are marked instead. In both
    cases tensors rejected by tensor_should_be_ignored() are left out.
    Returns lr_graph.
    """
    assert lr_graph is not None
    timestamp = lr_graph.current_time

    for ps in sg.passes:
        pass_tensors = ps.inputs + ps.outputs + ps.intermediates
        for tens in pass_tensors:
            skip = tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
                tens, target_mem_area, target_mem_type_set
            )
            if not skip:
                lr_graph.get_or_create_range(tens).mark_usage(timestamp)

    for op_info in sg.schedule.cost_map.values():
        for npu_tens in (op_info.npu_weights_tensor, op_info.npu_scales_tensor):
            if not npu_tens:
                continue
            if tensor_should_be_ignored(npu_tens, target_mem_area, target_mem_type_set):
                continue
            lr_graph.get_or_create_range(npu_tens).mark_usage(timestamp)

    lr_graph.current_time += 1
    return lr_graph


def extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph, verbose_progress=False):
    """Mark live ranges for the tensors used by the scheduled ops of a subgraph.

    Each scheduler op gets a time index (stored in op_info.time_index). Ops
    whose cascade number is non-zero reuse the time index of the first op seen
    with that cascade number; other ops take the current time of the graph.
    For ops without cascade info the ofm is fused into the live range of a
    reusable ifm where that is possible.

    Args:
        sg: Subgraph providing sched_ops, schedule (cost_map and cascades) and
            output_tensors.
        target_mem_area: Memory area a tensor must be in to be marked.
        target_mem_type_set: Memory types a tensor may have to be marked.
        lr_graph: Live range graph to extend.
        verbose_progress: Forwarded to progress_print().

    Returns:
        lr_graph.
    """
    # cascade number -> time index shared by the ops of that cascade
    time_for_cascade = {}
    for index, sched_op in enumerate(sg.sched_ops):
        progress_print(verbose_progress, "Processing SchedulerOp", index, sg.sched_ops)
        op_info = sg.schedule.cost_map[sched_op]
        cascade = op_info.cascade
        cascade_info = sg.schedule.cascades.get(cascade, None)

        if cascade_info is None:
            # Op is not part of a cascade, check if the ifm can be overwritten by the ofm
            merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set)

        # Reuse the time recorded for this cascade number, otherwise use the current time
        time_to_set = time_for_cascade.get(cascade, lr_graph.current_time)

        op_info.time_index = time_to_set

        # Mark usage for all tensors related to this Pass
        ps = sched_op.parent_ps
        for tens in ps.inputs + ps.outputs + ps.intermediates:
            if (
                target_mem_area == MemArea.Sram
                and cascade_info
                and tens == ps.ifm_tensor
                and sched_op in cascade_info.buffers
            ):
                # This tensor is a rolling buffer in a cascade and the size of the LiveRange needs to be modified
                # for enabling temporal memory snapshots without modifying the original Tensor
                # (this branch is taken before the purpose/mem_type/mem_area filter below is evaluated)
                rng = lr_graph.get_or_create_range(tens)
                rng.set_buffer_size(cascade_info.buffers[sched_op].elements() * sched_op.ifm.dtype.size_in_bytes())
            elif (
                tens.purpose == TensorPurpose.Weights
                or tens.purpose == TensorPurpose.FSBias
                or tens.mem_type not in target_mem_type_set
                or tens.mem_area != target_mem_area
            ):
                continue

            else:
                rng = lr_graph.get_or_create_range(tens)

            rng.mark_usage(time_to_set)

        # Buffered weight tensors are handled separately from the pass tensors above
        for idx, weight_tens in enumerate(op_info.buffered_weight_tensors):
            if weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
                rng = lr_graph.get_or_create_range(weight_tens)
                start_time = time_to_set
                length = 1
                if weight_tens.pre_buffer:
                    # A pre-buffered tensor is marked from one time step earlier
                    start_time -= 1
                    length += 1
                if len(op_info.buffered_weight_tensors) > 1:
                    last_idx = len(op_info.ofm_depth_slices) % len(op_info.buffered_weight_tensors)
                    # Double buffering: reduce end time of the buffer that is not used last
                    if last_idx != idx:
                        length -= 1
                rng.mark_usage(start_time, length)

        # Advance the time only when this op used the current time, i.e. it did not reuse
        # a time index recorded earlier for its cascade
        if time_to_set == lr_graph.current_time:
            lr_graph.current_time += 2

        # Remember the time index for later ops with the same cascade number
        # (presumably cascade number 0 means "not part of a cascade" - TODO confirm)
        if cascade != 0:
            time_for_cascade[cascade] = time_to_set

    # The subgraph outputs are marked at the time reached after the last op
    time_to_set = lr_graph.current_time
    for tens in sg.output_tensors:
        if tens.mem_type not in target_mem_type_set or tens.mem_area != target_mem_area:
            continue
        rng = lr_graph.get_or_create_range(tens)
        rng.mark_usage(time_to_set)

    return lr_graph