# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Functionality for lookup table support.
import uuid
from functools import lru_cache

import numpy as np

from . import numeric_util
from .high_level_command_stream import CommandType
from .tensor import create_const_tensor
from .tensor import TensorPurpose


@lru_cache(maxsize=None)
def create_equivalence_id(key):
    # Generates an equivalence_id based on key.
    # Because of the lru_cache decorator, repeated calls with the same key
    # return the same id, while distinct keys get new ids.
    # The DMA optimization of LUT-s assumes that 2 LUT tensors are identical
    # if they have the same equivalence_id.
    # So for example all created 256-byte tanh LUT tensors should have
    # the same equivalence id.
    return uuid.uuid4()

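# A minimal sketch of the caching behaviour (the keys are illustrative, not
# ones the compiler actually uses): thanks to lru_cache, identical keys map to
# the same id, while different keys get fresh ones:
#
#     id_a = create_equivalence_id(("tanh", 256))
#     id_b = create_equivalence_id(("tanh", 256))
#     id_c = create_equivalence_id(("sigmoid", 256))
#     assert id_a == id_b
#     assert id_a != id_c

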
class LUTState:
    # Tracks which LUT-s are located in SHRAM.
    def __init__(self):
        self.tensors = []

    def get_equivalent(self, lut_tens):
        # Returns existing LUT with same equivalence id, None if not found
        for t in self.tensors:
            if t.equivalent(lut_tens):
                return t
        return None

    def put(self, lut_tens):
        # Returns new LUT state containing given tensor + all tensors in this state
        # that do not overlap with the given tensor
        new_state = LUTState()
        new_state.tensors.append(lut_tens)
        start = lut_tens.address
        end = start + lut_tens.storage_size()
        for tens in self.tensors:
            start2 = tens.address
            end2 = start2 + tens.storage_size()
            if not numeric_util.overlaps(start, end, start2, end2):
                new_state.tensors.append(tens)
        return new_state

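    # A small sketch of the eviction rule above (addresses are made up): a LUT
    # placed at [0x400, 0x500) evicts any tracked LUT whose storage overlaps
    # that range, while non-overlapping LUTs are carried over. Note that put()
    # is non-destructive; it builds a new LUTState rather than mutating this one:
    #
    #     state = LUTState()
    #     state = state.put(lut_a)  # lut_a at [0x400, 0x500)
    #     state = state.put(lut_b)  # lut_b at [0x500, 0x600); both tracked
    #     state = state.put(lut_c)  # lut_c at [0x480, 0x580); evicts both
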
    def find_best_address(self, start, stop, step):
        # Finds the address in the given range that overlaps with the minimum
        # number of currently present LUT-s.
        # An improvement would be to also take future LUT usage into account.
        best_addr = start
        best_nr_overlaps = stop  # initial value, comfortably larger than any possible overlap count
        for addr in range(start, stop, step):
            nr_overlaps = 0
            for tens in self.tensors:
                start2 = tens.address
                end2 = start2 + tens.storage_size()
                if numeric_util.overlaps(addr, addr + step, start2, end2):
                    nr_overlaps += 1
            if nr_overlaps < best_nr_overlaps:
                best_nr_overlaps = nr_overlaps
                best_addr = addr
        return best_addr

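    # Worked example (numbers are illustrative): with start=0x400, stop=0x800
    # and step=0x100, the candidate base addresses are 0x400, 0x500, 0x600 and
    # 0x700; the earliest candidate that overlaps the fewest tracked LUTs wins.

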
def get_lut_index(arch, lut_tensor):
    # Returns the index of the SHRAM slot where the given LUT is stored, a value in the range 0-7
    slot = (lut_tensor.address - arch.shram_lut_address) // lut_tensor.storage_size()
    assert 0 <= slot < 8
    return slot

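# Worked example (addresses are assumed, not taken from a real configuration):
# with arch.shram_lut_address == 0x400 and a 256-byte LUT at address 0x600,
# the slot index is (0x600 - 0x400) // 256 == 2.

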
def create_lut_tensor(name, values, dtype):
    # Creates constant LUT tensor with the given values as lookup table.
    # The tensor's equivalence_id is based on these values, so if multiple
    # LUT tensors are created with identical values, they will get the same
    # address in constant memory, and unnecessary DMA operations can be avoided.
    sz = len(values)
    assert sz in (256, 512)
    ntype = np.uint8 if dtype.size_in_bytes() == 1 else np.uint32
    tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, ntype, TensorPurpose.LUT)
    tens.equivalence_id = create_equivalence_id(tuple(values))
    return tens

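# Hypothetical usage (DataType.int8 and the tanh formula are assumptions for
# illustration, not things this module defines): both tensors below are built
# from identical values, so they share one equivalence_id and at most one DMA
# of the table is ultimately needed:
#
#     values = [int(round(127 * math.tanh((i - 128) / 32.0))) for i in range(256)]
#     lut1 = create_lut_tensor("tanh_lut_1", values, DataType.int8)
#     lut2 = create_lut_tensor("tanh_lut_2", values, DataType.int8)

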
def optimize_high_level_cmd_stream(sg, arch):
    # - Allocates SHRAM address/lut index to LUT tensors
    # - Removes from sg's command stream any LUT DMA operations that are
    #   unnecessary because the LUT is already present in SHRAM
    cmd_stream = []  # will contain existing command stream minus unneeded DMA operations
    lut_state = LUTState()
    slot_size = 256
    lut_start = arch.shram_lut_address
    lut_end = lut_start + arch.shram_lut_size
    for cmd in sg.high_level_command_stream:
        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
            # The command overwrites the last 2 banks containing the LUT; the next LUT operation will require DMA
            # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
            lut_state = LUTState()
        if cmd.cmdtype != CommandType.DMA or cmd.out_tensor.purpose != TensorPurpose.LUT:
            # Non-LUT operation; leave untouched
            cmd_stream.append(cmd)
            continue
        # LUT DMA operation
        lut_tens = cmd.out_tensor
        existing_tens = lut_state.get_equivalent(lut_tens)
        if existing_tens is not None:
            # LUT is already in SHRAM, no need to perform DMA
            lut_tens.address = existing_tens.address
            cmd.ps.primary_op.attrs["lut_index"] = get_lut_index(arch, existing_tens)
            continue
        # Place the LUT in the last 2 blocks of SHRAM.
        # Alignment is always on the size of the LUT: 256 for a 256-byte LUT, 1K for a 1K LUT, etc.
        address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size())
        lut_tens.address = address
        cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size
        lut_state = lut_state.put(lut_tens)
        cmd_stream.append(cmd)
    sg.high_level_command_stream = cmd_stream

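# Sketch of the pass's effect (command names are illustrative): given a stream
#
#     DMA(tanh_lut), NpuStripe(op1), DMA(tanh_lut), NpuStripe(op2)
#
# where op1 itself uses the LUT (so the LUT banks are not overwritten), the
# second, redundant DMA is dropped and op2 reuses the already-resident table
# via its lut_index attribute.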