aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/lut.py
blob: 7b563b087f5f99ecd46b4e00aa90e29e24197d6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
# SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Functionality for lookup table support.
import uuid

import numpy as np

from . import fp_math
from . import numeric_util
from .data_type import DataType
from .debug_database import DebugDatabase
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import round_away_zero
from .operation import Op
from .scaling import quantise_scale
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import TensorPurpose


class LUTState:
    # Snapshot of the LUT tensors believed to be resident in SHRAM.
    # put() never modifies a state; it returns a new LUTState instead.
    def __init__(self):
        # Resident LUT tensors, most recently placed first (see put())
        self.tensors = []

    def get_equivalent(self, lut_tens):
        # Returns the first resident LUT whose values are identical to lut_tens.values,
        # or None when no such LUT is present.
        matches = (tens for tens in self.tensors if np.array_equal(tens.values, lut_tens.values))
        return next(matches, None)

    def put(self, lut_tens):
        # Returns a new LUTState holding lut_tens followed by every resident LUT whose address
        # range (address .. address + storage_size()) does not overlap the range occupied by
        # lut_tens, as decided by numeric_util.overlaps. This state itself is left unchanged.
        new_lo = lut_tens.address
        new_hi = new_lo + lut_tens.storage_size()
        kept = [
            tens
            for tens in self.tensors
            if not numeric_util.overlaps(new_lo, new_hi, tens.address, tens.address + tens.storage_size())
        ]
        result = LUTState()
        result.tensors = [lut_tens, *kept]
        return result

    def find_best_address(self, start, stop, step):
        # Scans the candidate addresses start, start + step, ... (below stop) and returns the first
        # one whose window (addr .. addr + step) overlaps the fewest resident LUTs.
        # Returns start when the candidate range is empty.
        # Future LUT usage is not taken into account; doing so could avoid some reloads.
        def nr_overlapping(addr):
            return sum(
                1
                for tens in self.tensors
                if numeric_util.overlaps(addr, addr + step, tens.address, tens.address + tens.storage_size())
            )

        # `stop` serves as the initial "worst" count; being an address, it presumably exceeds any
        # number of resident LUTs, so the first candidate always wins initially.
        best_addr, fewest = start, stop
        for addr in range(start, stop, step):
            count = nr_overlapping(addr)
            if count < fewest:
                best_addr, fewest = addr, count
        return best_addr


def get_lut_index(arch, lut_tensor):
    # Returns the SHRAM LUT slot (0..7 inclusive) that holds lut_tensor: the tensor's offset from
    # arch.shram_lut_address, floor-divided by arch.shram_lut_slot_size.
    # Raises AssertionError when the tensor lies outside the 8 slots.
    offset = lut_tensor.address - arch.shram_lut_address
    index = offset // arch.shram_lut_slot_size
    assert 0 <= index < 8
    return index


def create_lut_tensor(name, values, dtype):
    # Creates a constant LUT tensor of shape [1, 1, 1, len(values)] holding `values`.
    # Only 256- or 512-entry tables are accepted (AssertionError otherwise).
    # int16 LUTs are stored as uint32, because each entry packs a base and a slope
    # (see create_lut_int16_op).
    # The equivalence_id is derived from the values, so LUT tensors with identical contents share an
    # id; this lets them be given the same constant-memory address and avoids unnecessary DMA.
    nr_entries = len(values)
    assert nr_entries in (256, 512)
    if dtype == DataType.int16:
        storage_dtype = DataType.uint32
    else:
        storage_dtype = dtype
    lut = create_const_tensor(name, [1, 1, 1, nr_entries], storage_dtype, values, TensorPurpose.LUT)
    lut.equivalence_id = create_equivalence_id(tuple(values))
    return lut


def optimize_high_level_cmd_stream(sg, arch):
    # Allocates SHRAM addresses / LUT slot indices to LUT tensors and removes LUT DMA commands
    # whose LUT is already resident in SHRAM from sg.high_level_command_stream.
    #
    # The stream is walked in order while a LUTState tracks which LUTs are resident:
    # - an NpuStripe without a LUT is assumed to clobber the LUT area (when no SHRAM banks are
    #   reserved), so the tracked state is reset;
    # - a LUT DMA whose values are already resident is dropped, and its tensor and op are pointed at
    #   the resident copy;
    # - any other LUT DMA is kept and its LUT placed where it overlaps the fewest resident LUTs.
    # All non-LUT commands are kept unchanged. The result replaces sg.high_level_command_stream.
    cmd_stream = []
    lut_state = LUTState()
    lut_start = arch.shram_lut_address
    lut_end = lut_start + arch.shram_lut_size
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
            # The command overwrites the last 2 banks containing the LUT; next LUT operation will require DMA
            # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
            lut_state = LUTState()

        is_lut_dma = isinstance(cmd, DMA) and cmd.out_tensor.purpose == TensorPurpose.LUT
        if not is_lut_dma:
            # Non-LUT operation; leave untouched
            cmd_stream.append(cmd)
            continue

        lut_tens = cmd.out_tensor
        existing_tens = lut_state.get_equivalent(lut_tens)
        if existing_tens is not None:
            # LUT is already in SHRAM: reuse it and drop this DMA
            lut_tens.equivalence_id = existing_tens.equivalence_id
            lut_tens.address = existing_tens.address
            cmd.ps.primary_op.activation.lut_index = get_lut_index(arch, existing_tens)
            continue

        # Place the LUT in the last 2 blocks of SHRAM.
        # Alignment is always on the size of the LUT: 256 for a 256-byte LUT, 1K for a 1K LUT, etc.
        address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size())
        # A fresh id ensures this newly-placed LUT is not treated as equivalent to any earlier tensor
        lut_tens.equivalence_id = uuid.uuid4()
        lut_tens.address = address
        # Same slot computation (and 0..7 range check) as the reuse path above
        cmd.ps.primary_op.activation.lut_index = get_lut_index(arch, lut_tens)
        lut_state = lut_state.put(lut_tens)
        cmd_stream.append(cmd)
    sg.high_level_command_stream = cmd_stream


def convert_to_lut(op, lut_values, lut_name):
    # Rewrites op in place as an Add of the IFM and a scalar-zero constant, with lut_values attached
    # as the LUT activation, so that the LUT performs the actual function.
    #
    # op: operation with an int8, uint8 or int16 IFM and an OFM of the same dtype.
    # lut_values: LUT entries, passed on to create_lut_tensor().
    # lut_name: suffix appended to the op's name.
    # Returns op (returned unchanged when it has no IFM).
    # Raises AssertionError on an unsupported IFM dtype or an IFM/OFM dtype mismatch. Both checks run
    # before op is modified, so a failing check leaves the graph untouched.
    ifm = op.ifm
    ofm = op.ofm
    if ifm is None:
        return op
    assert ifm.dtype in (DataType.int8, DataType.uint8, DataType.int16)
    # The LUT tensor datatype needs to match both the OFM datatype (these are the values output) and the
    # datatype used to generate the LUT values (probably the IFM datatype). This avoids potential overflow
    # errors in create_lut_tensor() when converting Python ints (which could represent a uint) to NumPy ints.
    # Requiring identical IFM and OFM datatypes guarantees both.
    assert ifm.dtype == ofm.dtype

    op.type = Op.Add
    op.name = f"{op.name}_lut_{lut_name}"
    # Mark as no-op to enable potential fusing optimizations
    op.attrs["is_nop"] = True

    # Create an input tensor containing scalar zero, quantised with the IFM scale
    quant_max = 65536.0 if ifm.dtype == DataType.int16 else 255.0
    quantization = QuantizationParameters(0.0, quant_max)
    quantization.scale_f32 = ifm.quantization.scale_f32
    quantization.zero_point = 0
    tens = create_const_tensor(ifm.name + "_scalar0", [], ifm.dtype, [0], quantization=quantization)
    op.add_input_tensor(tens)

    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
    # should be the same as the IFM
    op.forced_output_quantization = ifm.quantization

    lut_tensor = create_lut_tensor(op.name + "_values", lut_values, ofm.dtype)
    op.set_activation_lut(lut_tensor)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)
    return op


def create_lut_8bit_op(op, lut_fn, fn_name):
    # Rewrites op as an 8-bit LUT activation that applies lut_fn (a function on real values).
    #
    # Each representable IFM value q (0..255 for uint8, otherwise -128..127) is dequantised with
    # the IFM scale/zero point and passed through lut_fn. The result is requantised with the OFM
    # scale/zero point (rounding away from zero) and clamped to that same q range.
    # Returns the op rewritten by convert_to_lut().
    in_quant = op.ifm.quantization
    out_quant = op.ofm.quantization
    in_scale, in_zp = in_quant.scale_f32, in_quant.zero_point
    out_scale, out_zp = out_quant.scale_f32, out_quant.zero_point

    lo, hi = (0, 255) if op.ifm.dtype == DataType.uint8 else (-128, 127)

    def lut_entry(q):
        y_real = lut_fn(in_scale * (q - in_zp))
        requantised = round_away_zero(y_real / out_scale) + out_zp
        return min(hi, max(lo, requantised))

    table = [lut_entry(q) for q in range(lo, hi + 1)]
    return convert_to_lut(op, table, fn_name)


def create_lut_int16_op(op, lut_fn, fn_name):
    # Rewrites op as an int16 LUT activation approximating lut_fn (a function on real values).
    #
    # Following the reference LUT generation, the real IFM range is split into 512 equal steps.
    # lut_fn is sampled at the start of each step, and every sample is corrected by a bias of half
    # the midpoint error (interpolated midpoint minus true midpoint). A final sample at the IFM
    # maximum gives 513 int16 samples, which are packed into 512 hardware entries of the form
    # (slope << 16) + base.
    #
    # op: operation whose IFM/OFM quantization defines the real input/output ranges
    #     (presumably int16 IFM/OFM, since the int16 value range is used for both).
    # lut_fn: callable mapping a real input value to a real output value.
    # fn_name: name suffix for the rewritten op.
    # Returns the op rewritten by convert_to_lut().
    ifm_scale = op.ifm.quantization.scale_f32
    ofm_scale = op.ofm.quantization.scale_f32
    zp_in = op.ifm.quantization.zero_point
    zp_out = op.ofm.quantization.zero_point

    table_min = np.iinfo(np.int16).min
    table_max = np.iinfo(np.int16).max

    input_min = ifm_scale * (table_min - zp_in)
    input_max = ifm_scale * (table_max - zp_in)
    output_min = ofm_scale * (table_min - zp_out)
    output_max = ofm_scale * (table_max - zp_out)

    # Create 16bit lut following the reference
    nbr_steps = 512
    step = (input_max - input_min) / nbr_steps
    half_step = step / 2
    output_scaling_inv = (table_max - table_min + 1) / (output_max - output_min)

    values = []
    for i in range(nbr_steps):
        val = lut_fn(input_min + i * step)
        val_midpoint = lut_fn(input_min + i * step + half_step)
        val_next = lut_fn(input_min + (i + 1) * step)

        sample_val = round_away_zero(val * output_scaling_inv)
        midpoint_interp_val = round_away_zero((val_next * output_scaling_inv + sample_val) / 2)
        midpoint_val = round_away_zero(val_midpoint * output_scaling_inv)
        midpoint_err = midpoint_interp_val - midpoint_val
        bias = round_away_zero(midpoint_err / 2)

        values.append(min(max(sample_val - bias, table_min), table_max))

    # Final sample at the top of the input range (513 samples in total)
    val = round_away_zero(lut_fn(input_max) * output_scaling_inv)
    values.append(min(max(val, table_min), table_max))

    # Convert to hardware 16bit lut with base and slope: entry i packs (slope << 16) + base, where base is
    # sample i and slope is the difference to sample i + 1.
    # The packed value is negative whenever the slope is negative (or zero with a negative base), but
    # the LUT is stored as uint32 (see create_lut_tensor). Reduce it modulo 2**32 explicitly: this keeps
    # the same two's complement bit pattern, whereas converting a negative Python int to a uint32 NumPy
    # array is deprecated in NumPy 1.24+ and raises OverflowError in NumPy 2.
    uint32_mask = 0xFFFFFFFF
    lut = []
    for base_val, next_val in zip(values[:-1], values[1:]):
        base = int(base_val)
        slope = (int(next_val) - base) << 16
        lut.append((slope + base) & uint32_mask)

    return convert_to_lut(op, lut, fn_name)


def create_lut_rsqrt_int8_op(op):
    # Rewrites an int8 RSQRT op as an 8-bit LUT activation.
    #
    # RSQRT_LUT[k] is 1/sqrt(k) in fixed point with 20 fractional bits (e.g. RSQRT_LUT[1] == 1 << 20,
    # RSQRT_LUT[4] == 1 << 19) for input offsets k = x - zp_in in 1..255. Index 0 is a placeholder,
    # because rsqrt(0) is unbounded. Each entry is requantised with scale 1 / (sqrt(ifm_scale) * ofm_scale)
    # plus the OFM zero point, then clamped to int8.
    # An input equal to the IFM zero point (real value 0) maps to the max output value, as in the
    # reference. Negative offsets (undefined for rsqrt) are clamped to offset 0.
    # Returns the op rewritten by convert_to_lut().

    # Turn off black formatting for the LUT tables to keep them compact
    # fmt: off

    # RSQRT_LUT has been generated by printing the output from the reference.
    # These values are always the same but for some unknown reason it is not being
    # implemented as a LUT in the reference.
    # So based on the input range (-128, 127) the reference produces the following output:
    RSQRT_LUT = [
        0x00000000, 0x00100000, 0x000b504e, 0x00093cd4, 0x00080000, 0x000727c9, 0x0006882f, 0x00060c24,
        0x0005a827, 0x00055555, 0x00050f45, 0x0004d2fe, 0x00049e6a, 0x00047007, 0x000446b4, 0x00042195,
        0x00040000, 0x0003e16d, 0x0003c570, 0x0003abb0, 0x000393e5, 0x00037dd2, 0x00036945, 0x00035613,
        0x00034418, 0x00033333, 0x0003234b, 0x00031447, 0x00030612, 0x0002f89c, 0x0002ebd3, 0x0002dfaa,
        0x0002d414, 0x0002c906, 0x0002be75, 0x0002b45a, 0x0002aaab, 0x0002a161, 0x00029875, 0x00028fe3,
        0x000287a2, 0x00027fb0, 0x00027807, 0x000270a2, 0x0002697f, 0x00026298, 0x00025bec, 0x00025577,
        0x00024f35, 0x00024925, 0x00024343, 0x00023d8e, 0x00023803, 0x000232a1, 0x00022d65, 0x0002284e,
        0x0002235a, 0x00021e87, 0x000219d5, 0x00021541, 0x000210cb, 0x00020c70, 0x00020831, 0x0002040c,
        0x00020000, 0x0001fc0c, 0x0001f82f, 0x0001f468, 0x0001f0b7, 0x0001ed1a, 0x0001e991, 0x0001e61b,
        0x0001e2b8, 0x0001df67, 0x0001dc26, 0x0001d8f7, 0x0001d5d8, 0x0001d2c8, 0x0001cfc8, 0x0001ccd6,
        0x0001c9f2, 0x0001c71c, 0x0001c454, 0x0001c198, 0x0001bee9, 0x0001bc46, 0x0001b9af, 0x0001b723,
        0x0001b4a3, 0x0001b22d, 0x0001afc2, 0x0001ad61, 0x0001ab0a, 0x0001a8bc, 0x0001a678, 0x0001a43e,
        0x0001a20c, 0x00019fe3, 0x00019dc2, 0x00019baa, 0x0001999a, 0x00019791, 0x00019590, 0x00019397,
        0x000191a5, 0x00018fbb, 0x00018dd7, 0x00018bfa, 0x00018a23, 0x00018853, 0x0001868a, 0x000184c6,
        0x00018309, 0x00018152, 0x00017fa0, 0x00017df4, 0x00017c4e, 0x00017aad, 0x00017911, 0x0001777b,
        0x000175e9, 0x0001745d, 0x000172d6, 0x00017153, 0x00016fd5, 0x00016e5b, 0x00016ce7, 0x00016b76,
        0x00016a0a, 0x000168a2, 0x0001673e, 0x000165de, 0x00016483, 0x0001632b, 0x000161d7, 0x00016087,
        0x00015f3b, 0x00015df2, 0x00015cad, 0x00015b6b, 0x00015a2d, 0x000158f2, 0x000157bb, 0x00015686,
        0x00015555, 0x00015427, 0x000152fd, 0x000151d5, 0x000150b0, 0x00014f8f, 0x00014e70, 0x00014d54,
        0x00014c3b, 0x00014b24, 0x00014a11, 0x00014900, 0x000147f1, 0x000146e5, 0x000145dc, 0x000144d5,
        0x000143d1, 0x000142cf, 0x000141d0, 0x000140d3, 0x00013fd8, 0x00013ee0, 0x00013de9, 0x00013cf5,
        0x00013c03, 0x00013b14, 0x00013a26, 0x0001393b, 0x00013851, 0x0001376a, 0x00013684, 0x000135a1,
        0x000134bf, 0x000133e0, 0x00013302, 0x00013226, 0x0001314c, 0x00013074, 0x00012f9e, 0x00012ec9,
        0x00012df6, 0x00012d25, 0x00012c55, 0x00012b87, 0x00012abb, 0x000129f1, 0x00012928, 0x00012860,
        0x0001279a, 0x000126d6, 0x00012613, 0x00012552, 0x00012492, 0x000123d4, 0x00012317, 0x0001225c,
        0x000121a2, 0x000120e9, 0x00012032, 0x00011f7c, 0x00011ec7, 0x00011e14, 0x00011d62, 0x00011cb1,
        0x00011c02, 0x00011b54, 0x00011aa7, 0x000119fb, 0x00011950, 0x000118a7, 0x000117ff, 0x00011758,
        0x000116b3, 0x0001160e, 0x0001156b, 0x000114c8, 0x00011427, 0x00011387, 0x000112e8, 0x0001124a,
        0x000111ad, 0x00011111, 0x00011076, 0x00010fdc, 0x00010f44, 0x00010eac, 0x00010e15, 0x00010d7f,
        0x00010cea, 0x00010c56, 0x00010bc4, 0x00010b32, 0x00010aa0, 0x00010a10, 0x00010981, 0x000108f3,
        0x00010865, 0x000107d9, 0x0001074d, 0x000106c2, 0x00010638, 0x000105af, 0x00010527, 0x0001049f,
        0x00010419, 0x00010393, 0x0001030e, 0x0001028a, 0x00010206, 0x00010183, 0x00010102, 0x00010080
    ]

    # Transform the above LUT so it gets the correct quantization (following the reference)
    ifm_scale = op.ifm.quantization.scale_f32
    ofm_scale = op.ofm.quantization.scale_f32
    zp_in = op.ifm.quantization.zero_point
    zp_out = op.ofm.quantization.zero_point

    scale = np.double(1) / np.double(np.sqrt(ifm_scale) * ofm_scale)
    output_multiplier, output_shift = quantise_scale(scale)

    # Shift modification (value used in reference but Vela has opposite sign)
    kshift = -20

    quantized_min = -128
    quantized_max = 127

    values = []
    for x in range(quantized_min, quantized_max + 1):
        offset = x - zp_in
        if x == quantized_min or offset == 0:
            # Any value close to 0 is mapped to the max output value (as in the reference).
            # offset == 0 is the real zero for any zp_in; the first LUT entry has always been mapped to
            # max as well (it is the real zero when zp_in == -128).
            values.append(quantized_max)
            continue
        # Rsqrt is only defined for positive values; negative offsets are clamped to index 0
        val = RSQRT_LUT[max(0, offset)]
        val = fp_math.multiply_by_quantized_multiplier(val, output_multiplier, output_shift - kshift) + zp_out
        values.append(min(quantized_max, max(quantized_min, val)))

    return convert_to_lut(op, values, "rsqrt")