aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
blob: cbfbda71e284100b2e1156c2990b9122f045949d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
/*
 * Copyright (c) 2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"

#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"

#include <map>

namespace arm_compute
{
namespace
{
/** Compute the fixed-point mean and variance of one row.
 *
 * @param[in] sum       Sum of the row's elements.
 * @param[in] sum_sq    Sum of the squares of the row's elements.
 * @param[in] num_input Number of elements in the row.
 *
 * @return (mean, variance). The mean is scaled by 2^10; the variance is in the elements' squared units and is
 *         never negative. Both are 0 when @p num_input is 0.
 */
inline std::pair<int64_t, int64_t> compute_mean_variance(int64_t sum, int64_t sum_sq, uint32_t num_input)
{
    // An empty row has no statistics; avoid a division by zero.
    if(num_input == 0)
    {
        return std::make_pair(int64_t{ 0 }, int64_t{ 0 });
    }

    const auto    n        = static_cast<int64_t>(num_input);
    const int64_t temp     = static_cast<int64_t>(0x100000) / n;
    const int64_t mean     = sum * 1024 / n;
    const int64_t variance = ((sum_sq * temp) - (mean * mean)) / 0x100000;

    // temp truncates 2^20 / num_input when num_input is not a power of two, which under-estimates the
    // E[x^2] term. For (near-)constant rows this can push the result below zero. A variance cannot be
    // negative, and the caller feeds it to an inverse-square-root routine, so clamp it to 0.
    return std::make_pair(mean, variance < 0 ? int64_t{ 0 } : variance);
}

/** Lane-wise widened multiply-accumulate: result[i] = bias[i] + a[i] * b[i], computed in 64 bits.
 *
 * @param[in] a    First factor (4 x int32).
 * @param[in] b    Second factor (4 x int32).
 * @param[in] bias Addend (4 x int32).
 *
 * @return Lanes 0-1 in val[0] and lanes 2-3 in val[1], each as int64.
 */
inline int64x2x2_t mul_add(const int32x4_t &a, const int32x4_t &b, const int32x4_t &bias)
{
    using namespace wrapper;

    // Widen both factors to 64 bits so the per-lane products cannot overflow.
    const int64x2_t lhs_lo = vmovl(vgetlow(a));
    const int64x2_t lhs_hi = vmovl(vgethigh(a));
    const int64x2_t rhs_lo = vmovl(vgetlow(b));
    const int64x2_t rhs_hi = vmovl(vgethigh(b));

    // The four products are formed lane by lane with scalar 64-bit multiplies.
    const int64x2_t prod_lo{ vgetlane(lhs_lo, 0) * vgetlane(rhs_lo, 0), vgetlane(lhs_lo, 1) * vgetlane(rhs_lo, 1) };
    const int64x2_t prod_hi{ vgetlane(lhs_hi, 0) * vgetlane(rhs_hi, 0), vgetlane(lhs_hi, 1) * vgetlane(rhs_hi, 1) };

    int64x2x2_t out;
    out.val[0] = vadd(vmovl(vgetlow(bias)), prod_lo);
    out.val[1] = vadd(vmovl(vgethigh(bias)), prod_hi);
    return out;
}
} // namespace

void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *output, const ITensor *weight, const ITensor *bias)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight, bias, output);
    ARM_COMPUTE_ERROR_ON(input == output);
    ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), weight->info(), bias->info()));

    static const std::map<DataType, ComputeFuncType> fn_map =
    {
        { DataType::QSYMM16, std::mem_fn(&NEQLSTMLayerNormalizationKernel::compute_qsymm16) },
    };

    _input  = input;
    _output = output;
    _weight = weight;
    _bias   = bias;
    _fn     = fn_map.at(_input->info()->data_type());

    auto_init_if_empty(*_output->info(), *_input->info());
    _output->info()->set_quantization_info(compute_output_qinfo());

    const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform();
    const Status                  s       = quantization::calculate_quantized_multiplier(wq_info.scale, &_output_multiplier, &_output_shift);
    _output_shift *= -1;

    if(!bool(s))
    {
        _output_multiplier = 0;
        _output_shift      = 0;
    }

    Window win = configure_window(output);
    INEKernel::configure(win);
}

/** Build the execution window for @p target and cache the derived x-range, step and helper windows.
 *
 * @param[in,out] target Tensor whose shape defines the window; its valid region is set to the whole tensor.
 *
 * @return The maximal window over @p target.
 */
Window NEQLSTMLayerNormalizationKernel::configure_window(ITensor *target)
{
    Window      window = calculate_max_window(*target->info(), Steps());
    Coordinates coord;
    coord.set_num_dimensions(target->info()->num_dimensions());
    target->info()->set_valid_region(ValidRegion(coord, target->info()->tensor_shape()));

    _window_start_x = static_cast<int32_t>(window.x().start());
    _window_end_x   = static_cast<int32_t>(window.x().end());
    // Elements of the target's type per vector of vector_size_byte bytes. Derive it from target (not a
    // member tensor) so the whole function is driven by its parameter. The vectorised loops load 8 int16
    // lanes per step, so this must come out as 8 for QSYMM16.
    _window_step_x = static_cast<int32_t>(vector_size_byte) / target->info()->element_size();

    // Input and output windows iterate over the y-axis, while each iteration handles the whole x-axis itself.
    _inout_window = window;
    _inout_window.set(Window::DimX, Window::Dimension(0, 1, 1));

    // Weight and bias cannot iterate along the y-axis since they are 1D.
    _weight_window = _inout_window;
    _weight_window.set(Window::DimY, Window::Dimension(0, 1, 1));

    return window;
}

/** Check whether the given tensor infos form a valid configuration for this kernel.
 *
 * @param[in] input  Input info: QSYMM16, at most max_input_dimension dimensions.
 * @param[in] output Output info: if already initialised, same data type and shape as @p input.
 * @param[in] weight Weight info: QSYMM16, x-size equal to the input's, at most max_weight_dimension dimensions.
 * @param[in] bias   Bias info: S32, same shape as @p weight, at most max_bias_dimension dimensions.
 *
 * @return An error Status describing the first failed check, or an empty Status on success.
 */
Status NEQLSTMLayerNormalizationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias)
{
    // Every argument is dereferenced below. Report a null one through the returned Status, like every
    // other check here, rather than through an assertion macro.
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weight, bias, output);

    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QSYMM16);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weight, 1, DataType::QSYMM16);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);

    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > max_input_dimension);
    ARM_COMPUTE_RETURN_ERROR_ON(weight->num_dimensions() > max_weight_dimension);
    ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > max_bias_dimension);

    ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().x() != weight->tensor_shape().x());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(weight, bias);

    // An uninitialised output is filled in by configure(); an initialised one must match the input.
    if(output->total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
    }

    return Status{};
}

void NEQLSTMLayerNormalizationKernel::run(const Window &window, const ThreadInfo &info)
{
    // window is referenced only by the assertion macros below and info is not used at all. The
    // selected compute routine processes the windows prepared in configure_window().
    ARM_COMPUTE_UNUSED(window, info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON_MSG(_fn == nullptr, "internal function is not defined for computation");

    _fn(*this);
}

/** Quantization info applied to the output: a fixed scale of 2^-12 (Q3.12 for a 16-bit value). */
inline QuantizationInfo NEQLSTMLayerNormalizationKernel::compute_output_qinfo()
{
    constexpr float output_scale = 1.f / 4096;
    return QuantizationInfo(output_scale);
}

/** Sum the elements of one row, and their squares, over [_window_start_x, _window_end_x).
 *
 * @param[in] input_ptr Pointer to the first element of the row.
 *
 * @return (sum, sum of squares), both accumulated in 64 bits.
 */
inline std::pair<int64_t, int64_t> NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr)
{
    ARM_COMPUTE_ERROR_ON(!input_ptr);

    using AccType       = int64_t;
    using InputDataType = int16_t;

    AccType sum{ 0 };
    AccType sum_sq{ 0 };

    int32_t x = _window_start_x;
    for(; x <= _window_end_x && _window_step_x <= (_window_end_x - x); x += _window_step_x)
    {
        using namespace wrapper;
        const int16x8_t val      = vloadq(input_ptr + x);
        const int32x4_t val_low  = vmovl(vgetlow(val));
        const int32x4_t val_high = vmovl(vgethigh(val));

#if defined(__aarch64__)
        // Four widened int16 values sum to at most 4 * 2^15 in magnitude, so a 32-bit across-vector add is safe.
        sum += static_cast<AccType>(vaddv(val_low));
        sum += static_cast<AccType>(vaddv(val_high));
#else  // __aarch64__
        // only AArch64 supports vaddv
        const int64x2_t pair_sum = vadd(vpaddl(val_low), vpaddl(val_high));
        sum += vgetlane(pair_sum, 0) + vgetlane(pair_sum, 1);
#endif // __aarch64__

        // One square fits in int32 (|v| <= 2^15, so v * v <= 2^30), but four squares can reach 2^32 and
        // would wrap in a 32-bit across-vector add. Widen pairwise to 64 bits before reducing, on every architecture.
        const int64x2_t pair_sum_sq = vadd(vpaddl(vmul(val_low, val_low)), vpaddl(vmul(val_high, val_high)));
        sum_sq += vgetlane(pair_sum_sq, 0) + vgetlane(pair_sum_sq, 1);
    }

    // Scalar tail for elements that do not fill a whole vector.
    for(; x < _window_end_x; ++x)
    {
        const InputDataType val = input_ptr[x];
        sum += static_cast<AccType>(val);
        sum_sq += static_cast<AccType>(val * val);
    }

    return std::make_pair(sum, sum_sq);
}

inline void NEQLSTMLayerNormalizationKernel::normalize_qasymm16(const int16_t *input_ptr,
                                                                int16_t       *output_ptr,
                                                                const int16_t *weight_ptr,
                                                                const int32_t *bias_ptr,
                                                                int32_t mean, int32_t inv_std_mul, int32_t inv_std_shift)
{
    using OutputDataType = int16_t;

    using namespace wrapper;
    const int32x4_t mean_vec = vdup_n(mean, wrapper::traits::vector_128_tag{});

    int32_t x = _window_start_x;
    for(; x <= _window_end_x && _window_step_x <= (_window_end_x - x); x += _window_step_x)
    {
        const int16x8_t val = vloadq(input_ptr + x);
        int32x4x2_t     shifted;
        shifted.val[0] = vsub(vshlq_n_s32(vmovl(vgetlow(val)), 10), mean_vec);
        shifted.val[1] = vsub(vshlq_n_s32(vmovl(vgethigh(val)), 10), mean_vec);

        int32x4x2_t rescaled = multiply_by_quantized_multiplier_2row(shifted, inv_std_mul, inv_std_shift);

        const int16x8_t weight_val  = vloadq(weight_ptr + x);
        const int32x4_t weight_low  = vmovl(vgetlow(weight_val));
        const int32x4_t weight_high = vmovl(vgethigh(weight_val));

        const int32x4_t bias_low  = vloadq(bias_ptr + x);
        const int32x4_t bias_high = vloadq(bias_ptr + 4 + x);

        int64x2x2_t result_0 = mul_add(rescaled.val[0], weight_low, bias_low);
        int64x2x2_t result_1 = mul_add(rescaled.val[1], weight_high, bias_high);

        int32x4x2_t combined;
        combined.val[0] = vcombine(vmovn(vrshrq_n_s64(result_0.val[0], 10)), vmovn(vrshrq_n_s64(result_0.val[1], 10)));
        combined.val[1] = vcombine(vmovn(vrshrq_n_s64(result_1.val[0], 10)), vmovn(vrshrq_n_s64(result_1.val[1], 10)));

        int32x4x2_t out_val = multiply_by_quantized_multiplier_2row(combined, _output_multiplier, _output_shift + 12);

        vstore(output_ptr + x, vqmovn(out_val.val[0]));
        vstore(output_ptr + x + 4, vqmovn(out_val.val[1]));
    }

    for(; x < _window_end_x; ++x)
    {
        const auto    val             = static_cast<int32_t>(input_ptr[x]);
        const int32_t shifted         = (val << 10) - mean;
        const int32_t rescaled        = quantization::multiply_by_quantized_multiplier(shifted, inv_std_mul, inv_std_shift);
        const int64_t weighted        = rescaled * weight_ptr[x] + bias_ptr[x];
        const auto    reverse_shifted = static_cast<int32_t>((weighted + 512) >> 10);
        int32_t       out_val         = quantization::multiply_by_quantized_multiplier(reverse_shifted, _output_multiplier, _output_shift + 12);
        out_val                       = utility::clamp<decltype(out_val), OutputDataType>(out_val, std::numeric_limits<OutputDataType>::min());
        output_ptr[x]                 = static_cast<OutputDataType>(out_val);
    }
}

void NEQLSTMLayerNormalizationKernel::compute_qsymm16()
{
    using InputDataType  = int16_t;
    using OutputDataType = int16_t;
    using BiasDataType   = int32_t;
    using AccType        = int64_t;

    Iterator input_iterator{ _input, _inout_window };
    Iterator output_iterator{ _output, _inout_window };
    Iterator weight_iterator{ _weight, _weight_window };
    Iterator bias_iterator{ _bias, _weight_window };

    const auto weight_ptr = reinterpret_cast<const InputDataType *>(weight_iterator.ptr());
    const auto bias_ptr   = reinterpret_cast<const BiasDataType *>(bias_iterator.ptr());

    const uint32_t column_size = _input->info()->tensor_shape()[0];

    execute_window_loop(_inout_window, [ &, this](const Coordinates &)
    {
        const auto in_ptr  = reinterpret_cast<const InputDataType *>(input_iterator.ptr());
        auto       out_ptr = reinterpret_cast<OutputDataType *>(output_iterator.ptr());

        AccType sum{ 0 };
        AccType sum_sq{ 0 };
        std::tie(sum, sum_sq) = sum_qsymm16(in_ptr);

        AccType mean{ 0 };
        AccType variance{ 0 };
        std::tie(mean, variance) = compute_mean_variance(sum, sum_sq, column_size);

        int32_t stddev_invsqrt_mul{};
        int32_t stddev_invsqrt_shift{};
        quantization::get_invsqrt_quantized_multiplier_exp(static_cast<int32_t>(variance), -1, stddev_invsqrt_mul, stddev_invsqrt_shift);

        normalize_qasymm16(in_ptr, out_ptr, weight_ptr, bias_ptr, mean, stddev_invsqrt_mul, stddev_invsqrt_shift);
    },
    input_iterator, output_iterator);
}
} // namespace arm_compute