author     Sang-Hoon Park <sang-hoon.park@arm.com>    2020-03-13 14:56:05 +0000
committer  Sang-Hoon Park <sang-hoon.park@arm.com>    2020-04-07 09:00:09 +0000
commit     0d008f77b0085619c446d0ab5dc1228a80776706
tree       e1f6e91bf8da63e8ef98e11ab8eb6a6972a284f2
parent     4df2cf3177129d10500d30056bf8404418f703d6
download   ComputeLibrary-0d008f77b0085619c446d0ab5dc1228a80776706.tar.gz
COMPMID-3281: Implement QSYMM16 Layer Normalization for NEON QLSTM
- Reference kernel is modified to use the same algorithm as NEON kernel.
- NEON kernel is implemented.
- Tests for validation and run are added.

Change-Id: I3533bc2bd12c6e9cc75d837ecf193f74ceddf796
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2948
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
-rw-r--r--  Android.bp                                                       |   1
-rw-r--r--  arm_compute/core/NEON/NEKernels.h                                |   1
-rw-r--r--  arm_compute/core/NEON/NESymm.h                                   |   2
-rw-r--r--  arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h  | 135
-rw-r--r--  arm_compute/core/NEON/wrapper/intrinsics/getlane.h               |  17
-rw-r--r--  arm_compute/core/utils/quantization/AsymmHelpers.h               |  15
-rw-r--r--  src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp       | 316
-rw-r--r--  src/core/utils/quantization/AsymmHelpers.cpp                     |  84
-rw-r--r--  tests/validation/NEON/QLSTMLayerNormalization.cpp                | 220
-rw-r--r--  tests/validation/fixtures/QLSTMLayerNormalizationFixture.h       | 143
-rw-r--r--  tests/validation/reference/QLSTMLayerNormalization.cpp           |  74
11 files changed, 966 insertions, 42 deletions
diff --git a/Android.bp b/Android.bp
index 0cb0b7770e..528467a44e 100644
--- a/Android.bp
+++ b/Android.bp
@@ -322,6 +322,7 @@ cc_library_static {
"src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp",
"src/core/NEON/kernels/NEPoolingLayerKernel.cpp",
"src/core/NEON/kernels/NEPriorBoxLayerKernel.cpp",
+ "src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp",
"src/core/NEON/kernels/NEQuantizationLayerKernel.cpp",
"src/core/NEON/kernels/NEROIAlignLayerKernel.cpp",
"src/core/NEON/kernels/NEROIPoolingLayerKernel.cpp",
diff --git a/arm_compute/core/NEON/NEKernels.h b/arm_compute/core/NEON/NEKernels.h
index d9f8f00c0b..38701f434a 100644
--- a/arm_compute/core/NEON/NEKernels.h
+++ b/arm_compute/core/NEON/NEKernels.h
@@ -120,6 +120,7 @@
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
#include "arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h"
#include "arm_compute/core/NEON/kernels/NEPriorBoxLayerKernel.h"
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
#include "arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h"
#include "arm_compute/core/NEON/kernels/NEROIAlignLayerKernel.h"
#include "arm_compute/core/NEON/kernels/NEROIPoolingLayerKernel.h"
diff --git a/arm_compute/core/NEON/NESymm.h b/arm_compute/core/NEON/NESymm.h
index 0cc2a963cf..d6c5a7073a 100644
--- a/arm_compute/core/NEON/NESymm.h
+++ b/arm_compute/core/NEON/NESymm.h
@@ -239,7 +239,7 @@ inline qsymm16x8x2_t vquantize_qsymm16(const float32x4x4_t &qv, const UniformQua
*
* @return A neon vector holding the multiplied value
*/
-inline int32x4x2_t multiply_by_quantized_multipler_2row(int32x4x2_t input, int32_t qmul, int32_t shift)
+inline int32x4x2_t multiply_by_quantized_multiplier_2row(int32x4x2_t input, int32_t qmul, int32_t shift)
{
const auto left_shift = shift > 0 ? shift : 0;
const auto right_shift = shift > 0 ? 0 : -shift;
diff --git a/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h b/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
new file mode 100644
index 0000000000..631de66cc2
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_NEQLSTMLAYERNORMALIZATIONKERNEL_H
+#define ARM_COMPUTE_NEQLSTMLAYERNORMALIZATIONKERNEL_H
+
+#include "arm_compute/core/NEON/INEKernel.h"
+#include <functional>
+
+namespace arm_compute
+{
+class ITensor;
+
+/** NEON kernel to perform layer normalization */
+class NEQLSTMLayerNormalizationKernel : public INEKernel
+{
+public:
+ const char *name() const override
+ {
+ return "NEQLSTMLayerNormalizationKernel";
+ }
+ /** Default constructor */
+ NEQLSTMLayerNormalizationKernel() = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEQLSTMLayerNormalizationKernel(const NEQLSTMLayerNormalizationKernel &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEQLSTMLayerNormalizationKernel &operator=(const NEQLSTMLayerNormalizationKernel &) = delete;
+ /** Default Move Constructor. */
+ NEQLSTMLayerNormalizationKernel(NEQLSTMLayerNormalizationKernel &&) = default;
+ /** Default move assignment operator */
+ NEQLSTMLayerNormalizationKernel &operator=(NEQLSTMLayerNormalizationKernel &&) = default;
+ /** Default destructor */
+ ~NEQLSTMLayerNormalizationKernel() = default;
+
+ /** Set the input and output tensors.
+ *
+ * @param[in] input Source tensor. Data types supported: QSYMM16.
+ * @param[out] output Destination tensor. Data types supported: Same as @p input.
+ * @param[in] weight Weight tensor. Data types supported: Same as @p input.
+ * @param[in] bias Bias tensor. Data types supported: S32
+ */
+ void configure(const ITensor *input, ITensor *output, const ITensor *weight, const ITensor *bias);
+ /** Static function to check if given info will lead to a valid configuration of @ref NEQLSTMLayerNormalizationKernel
+ *
+ * @param[in] input Source tensor info. Data types supported: QSYMM16.
+ * @param[in] output Destination tensor info. Data types supported: Same as @p input.
+ * @param[in] weight Weight tensor info. Data types supported: Same as @p input.
+ * @param[in] bias Bias tensor info. Data types supported: S32
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias);
+ // Inherited methods overridden:
+ void run(const Window &window, const ThreadInfo &info) override;
+
+private:
+ // constants
+ static constexpr uint32_t max_input_dimension{ 2 }; /**< The maximum input dimension supported */
+ static constexpr uint32_t max_weight_dimension{ 1 }; /**< The maximum weight dimension supported */
+ static constexpr uint32_t max_bias_dimension{ 1 }; /**< The maximum bias dimension supported */
+ static constexpr uint32_t vector_size_byte{ 16 }; /**< Computation vector size in byte */
+
+ using ComputeFuncType = std::function<void(NEQLSTMLayerNormalizationKernel &)>;
+
+ ComputeFuncType _fn{}; /**< Function pointer to computation function */
+
+ const ITensor *_input{ nullptr }; /**< Input tensor */
+ const ITensor *_weight{ nullptr }; /**< Weight tensor */
+ const ITensor *_bias{ nullptr }; /**< Bias tensor */
+ ITensor *_output{ nullptr }; /**< Output tensor */
+
+ int32_t _output_multiplier{}; /**< Multiplier for output values */
+ int32_t _output_shift{}; /**< Shift value for output values */
+
+ int32_t _window_start_x{}; /**< The beginning of x-axis iteration */
+ int32_t _window_end_x{}; /**< The end of x-axis iteration */
+ int32_t _window_step_x{}; /**< The size of x-axis iteration's step */
+
+ Window _inout_window{}; /**< Window for input and output tensor */
+ Window _weight_window{}; /**< Window for weight and bias tensor */
+
+ /** Function to configure the initial windows for the destination of the computation
+ *
+ * @param[in] target Destination tensor to use for the output window
+ *
+ * @return configured window
+ */
+ Window configure_window(ITensor *target);
+ // Function to compute layer normalization for data type QSYMM16
+ void compute_qsymm16();
+ /** Function to compute the sum and the sum of squares of the given input row
+ *
+ * @param[in] input_ptr Pointer to the input array
+ *
+ * @return a pair of (sum, sum of squares)
+ */
+ std::pair<int64_t, int64_t> sum_qsymm16(const int16_t *input_ptr);
+ /** Function to normalize values using computed mean and standard deviation
+ *
+ * @param[in] input_ptr Pointer to input array
+ * @param[in] output_ptr Pointer to output array
+ * @param[in] weight_ptr Pointer to weight array
+ * @param[in] bias_ptr Pointer to bias array
+ * @param[in] mean Mean value
+ * @param[in] inv_std_mul Quantized multiplier for standard deviation
+ * @param[in] inv_std_shift Shift for standard deviation
+ *
+ */
+ void normalize_qasymm16(const int16_t *input_ptr,
+ int16_t *output_ptr,
+ const int16_t *weight_ptr,
+ const int32_t *bias_ptr,
+ int32_t mean, int32_t inv_std_mul, int32_t inv_std_shift);
+};
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_NEQLSTMLAYERNORMALIZATIONKERNEL_H */
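
A minimal usage sketch of the interface declared above, mirroring what the test fixture later in this patch does. The shapes, the weight scale and the single-threaded run call are illustrative assumptions, not part of the patch:

#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_qlstm_layer_norm_example()
{
    // 2 batches of 15 QSYMM16 values, with 1D weight/bias of matching length
    Tensor input, weight, bias, output;
    input.allocator()->init(TensorInfo(TensorShape(15U, 2U), 1, DataType::QSYMM16));
    weight.allocator()->init(TensorInfo(TensorShape(15U), 1, DataType::QSYMM16, QuantizationInfo(1.f / 8192.f)));
    bias.allocator()->init(TensorInfo(TensorShape(15U), 1, DataType::S32));
    output.allocator()->init(TensorInfo(TensorShape(15U, 2U), 1, DataType::QSYMM16));

    // Configure before allocation, as the validation fixture does
    NEQLSTMLayerNormalizationKernel norm;
    norm.configure(&input, &output, &weight, &bias);

    input.allocator()->allocate();
    weight.allocator()->allocate();
    bias.allocator()->allocate();
    output.allocator()->allocate();
    // ... fill input/weight/bias here ...

    // Run the kernel single-threaded over its full window
    ThreadInfo tinfo;
    tinfo.cpu_info = &NEScheduler::get().cpu_info();
    norm.run(norm.window(), tinfo);
}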
diff --git a/arm_compute/core/NEON/wrapper/intrinsics/getlane.h b/arm_compute/core/NEON/wrapper/intrinsics/getlane.h
index 5cd390fee4..533bf63603 100644
--- a/arm_compute/core/NEON/wrapper/intrinsics/getlane.h
+++ b/arm_compute/core/NEON/wrapper/intrinsics/getlane.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -185,6 +185,20 @@ VGETLANE_IMPL_4(float16_t, float16x4_t, f16)
} \
}
+#define VGETQLANE_IMPL_2(stype, vtype, postfix) \
+ inline stype vgetlane(const vtype vector, const unsigned int lane) \
+ { \
+ switch(lane) \
+ { \
+ case 0: \
+ return vgetq_lane_##postfix(vector, 0); \
+ case 1: \
+ return vgetq_lane_##postfix(vector, 1); \
+ default: \
+ ARM_COMPUTE_ERROR("Invalid lane"); \
+ } \
+ }
+
VGETQLANE_IMPL_16(uint8_t, uint8x16_t, u8)
VGETQLANE_IMPL_16(int8_t, int8x16_t, s8)
VGETQLANE_IMPL_8(uint16_t, uint16x8_t, u16)
@@ -192,6 +206,7 @@ VGETQLANE_IMPL_8(int16_t, int16x8_t, s16)
VGETQLANE_IMPL_4(uint32_t, uint32x4_t, u32)
VGETQLANE_IMPL_4(int32_t, int32x4_t, s32)
VGETQLANE_IMPL_4(float, float32x4_t, f32)
+VGETQLANE_IMPL_2(int64_t, int64x2_t, s64)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
VGETQLANE_IMPL_8(float16_t, float16x8_t, f16)
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
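
The new int64x2_t specialization is what lets the layer-normalization kernel read the 64-bit pair-sum lanes back out through the wrapper. A minimal sketch of that pattern (the helper name is illustrative, not part of the patch):

#include "arm_compute/core/NEON/wrapper/wrapper.h"

// Reduce a 64-bit NEON pair-sum to a scalar via the new int64x2_t vgetlane overload
inline int64_t reduce_s64(const int64x2_t &v)
{
    return arm_compute::wrapper::vgetlane(v, 0) + arm_compute::wrapper::vgetlane(v, 1);
}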
diff --git a/arm_compute/core/utils/quantization/AsymmHelpers.h b/arm_compute/core/utils/quantization/AsymmHelpers.h
index 0f0ec72b60..a7bbf9b137 100644
--- a/arm_compute/core/utils/quantization/AsymmHelpers.h
+++ b/arm_compute/core/utils/quantization/AsymmHelpers.h
@@ -128,7 +128,7 @@ int32_t saturating_rounding_doubling_highmul(int32_t a, int32_t b);
*
* @return The multiplied value
*/
-int32_t multiply_by_quantized_multipler(int32_t input, int32_t qmul, int32_t shift);
+int32_t multiply_by_quantized_multiplier(int32_t input, int32_t qmul, int32_t shift);
/** Compute the value multiplied by the power-of-two
*
@@ -137,7 +137,18 @@ int32_t multiply_by_quantized_multipler(int32_t input, int32_t qmul, int32_t shi
*
* @return The multiplied value
*/
-int32_t saturating_rounding_multiply_by_pow2(int exponent, int32_t v);
+int32_t saturating_rounding_multiply_by_pow2(int32_t exponent, int32_t v);
+
+/** Compute the quantized multiplier and shift for the inverse square root of the input.
+ * Uses 3-bit fixed point and 5 iterations of the Newton-Raphson method.
+ *
+ * @param[in] input Input to use
+ * @param[in] reverse_shift -1 to reverse the shift direction
+ * @param[out] output_inv_sqrt Quantized multiplier for inverse square root
+ * @param[out] output_shift Shift for inverse square root
+ *
+ */
+void get_invsqrt_quantized_multiplier_exp(int32_t input, int32_t reverse_shift, int32_t &output_inv_sqrt, int32_t &output_shift);
} // namespace quantization
} // namespace arm_compute
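
A hedged reading of this contract (informal, inferred from how the new kernels use it rather than from separate documentation): the returned pair encodes 1/sqrt(input) in the same multiplier-plus-power-of-two-shift form consumed by multiply_by_quantized_multiplier, i.e. roughly

    multiply_by_quantized_multiplier(v, output_inv_sqrt, output_shift) ≈ v / sqrt(input)

and passing reverse_shift = -1, as both the NEON and reference kernels do, flips the sign of the shift so the result can be fed straight into that function's left-shift-positive convention.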
diff --git a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
new file mode 100644
index 0000000000..db2ff85db9
--- /dev/null
+++ b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
@@ -0,0 +1,316 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
+
+#include "arm_compute/core/CPP/Validate.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/NEON/NEFixedPoint.h"
+#include "arm_compute/core/NEON/NEMath.h"
+#include "arm_compute/core/NEON/NESymm.h"
+#include "arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+
+#include <map>
+
+namespace arm_compute
+{
+namespace
+{
+inline std::pair<int64_t, int64_t> compute_mean_variance(int64_t sum, int64_t sum_sq, uint32_t num_input)
+{
+ const auto temp = static_cast<int64_t>(0x100000) / num_input;
+ const auto mean = sum * 1024 / static_cast<int64_t>(num_input);
+ const int64_t variance = ((sum_sq * temp) - (mean * mean)) / 0x100000;
+
+ return std::make_pair(mean, variance);
+}
+
+inline int64x2x2_t mul_add(const int32x4_t &a, const int32x4_t &b, const int32x4_t &bias)
+{
+ using namespace wrapper;
+ const int64x2_t a_low = vmovl(vgetlow(a));
+ const int64x2_t a_high = vmovl(vgethigh(a));
+ const int64x2_t b_low = vmovl(vgetlow(b));
+ const int64x2_t b_high = vmovl(vgethigh(b));
+
+ const int64_t a_0 = vgetlane(a_low, 0);
+ const int64_t a_1 = vgetlane(a_low, 1);
+ const int64_t a_2 = vgetlane(a_high, 0);
+ const int64_t a_3 = vgetlane(a_high, 1);
+
+ const int64_t b_0 = vgetlane(b_low, 0);
+ const int64_t b_1 = vgetlane(b_low, 1);
+ const int64_t b_2 = vgetlane(b_high, 0);
+ const int64_t b_3 = vgetlane(b_high, 1);
+
+ int64x2x2_t result;
+ const int64x2_t result_0{ a_0 * b_0, a_1 * b_1 };
+ const int64x2_t result_1{ a_2 * b_2, a_3 * b_3 };
+ result.val[0] = vadd(vmovl(vgetlow(bias)), result_0);
+ result.val[1] = vadd(vmovl(vgethigh(bias)), result_1);
+
+ return result;
+}
+} // namespace
+
+void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *output, const ITensor *weight, const ITensor *bias)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight);
+ ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(),
+ output ? output->info() : nullptr,
+ weight->info(),
+ bias ? bias->info() : nullptr));
+
+ static const std::map<DataType, ComputeFuncType> fn_map =
+ {
+ { DataType::QSYMM16, std::mem_fn(&NEQLSTMLayerNormalizationKernel::compute_qsymm16) },
+ };
+
+ _input = input;
+ _output = output;
+ _weight = weight;
+ _bias = bias;
+ _fn = fn_map.at(_input->info()->data_type());
+
+ auto_init_if_empty(*_output->info(), *_input->info());
+
+ const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform();
+ const Status s = quantization::calculate_quantized_multiplier(wq_info.scale, &_output_multiplier, &_output_shift);
+ _output_shift *= -1;
+
+ if(!bool(s))
+ {
+ _output_multiplier = 0;
+ _output_shift = 0;
+ }
+
+ Window win = configure_window(output);
+ INEKernel::configure(win);
+}
+
+Window NEQLSTMLayerNormalizationKernel::configure_window(ITensor *target)
+{
+ Window window = calculate_max_window(*target->info(), Steps());
+ Coordinates coord;
+ coord.set_num_dimensions(target->info()->num_dimensions());
+ target->info()->set_valid_region(ValidRegion(coord, target->info()->tensor_shape()));
+
+ _window_start_x = static_cast<int32_t>(window.x().start());
+ _window_end_x = static_cast<int32_t>(window.x().end());
+ _window_step_x = static_cast<int32_t>(vector_size_byte) / _output->info()->element_size();
+
+ // The input and output windows will iterate over the y-axis, while execute_window will handle the x-axis.
+ _inout_window = window;
+ _inout_window.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // The weight and bias windows cannot iterate along the y-axis since the tensors are 1D.
+ _weight_window = _inout_window;
+ _weight_window.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ return window;
+}
+
+Status NEQLSTMLayerNormalizationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias)
+{
+ ARM_COMPUTE_UNUSED(output, bias, weight, input);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight, bias, output);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QSYMM16);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weight, 1, DataType::QSYMM16);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > max_input_dimension);
+ ARM_COMPUTE_RETURN_ERROR_ON(weight->num_dimensions() > max_weight_dimension);
+ ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > max_bias_dimension);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().x() != weight->tensor_shape().x());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(weight, bias);
+
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+
+ return Status{};
+}
+
+void NEQLSTMLayerNormalizationKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(window, info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+ ARM_COMPUTE_ERROR_ON_MSG(!_fn, "internal function is not defined for computation");
+
+ _fn(*this);
+}
+
+inline std::pair<int64_t, int64_t> NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr)
+{
+ ARM_COMPUTE_ERROR_ON(!input_ptr);
+
+ using AccType = int64_t;
+ using InputDataType = int16_t;
+
+ AccType sum{ 0 };
+ AccType sum_sq{ 0 };
+
+ int32_t x = _window_start_x;
+ for(; x <= _window_end_x && _window_step_x <= (_window_end_x - x); x += _window_step_x)
+ {
+ using namespace wrapper;
+ const int16x8_t val = vloadq(input_ptr + x);
+ const int32x4_t val_low = vmovl(vgetlow(val));
+ const int32x4_t val_high = vmovl(vgethigh(val));
+
+#if defined(__aarch64__)
+ sum += static_cast<AccType>(vaddv(val_low));
+ sum += static_cast<AccType>(vaddv(val_high));
+
+ sum_sq += static_cast<AccType>(vaddv(vmul(val_low, val_low)));
+ sum_sq += static_cast<AccType>(vaddv(vmul(val_high, val_high)));
+#else // __aarch64__
+ // only AArch64 supports vaddv
+ const int64x2_t pair_sum_low = vpaddl(val_low);
+ const int64x2_t pair_sum_high = vpaddl(val_high);
+ const int64x2_t pair_sum = vadd(pair_sum_low, pair_sum_high);
+ sum += vgetlane(pair_sum, 0) + vgetlane(pair_sum, 1);
+
+ const int32x4_t square_low = vmul(val_low, val_low);
+ const int32x4_t square_high = vmul(val_high, val_high);
+ const int64x2_t pair_sum_sq_low = vpaddl(square_low);
+ const int64x2_t pair_sum_sq_high = vpaddl(square_high);
+ const int64x2_t pair_sum_sq = vadd(pair_sum_sq_low, pair_sum_sq_high);
+ sum_sq += vgetlane(pair_sum_sq, 0) + vgetlane(pair_sum_sq, 1);
+#endif // __aarch64__
+ }
+
+ for(; x < _window_end_x; ++x)
+ {
+ const InputDataType val = input_ptr[x];
+ sum += static_cast<AccType>(val);
+ sum_sq += static_cast<AccType>(val * val);
+ }
+
+ return std::make_pair(sum, sum_sq);
+}
+
+inline void NEQLSTMLayerNormalizationKernel::normalize_qasymm16(const int16_t *input_ptr,
+ int16_t *output_ptr,
+ const int16_t *weight_ptr,
+ const int32_t *bias_ptr,
+ int32_t mean, int32_t inv_std_mul, int32_t inv_std_shift)
+{
+ using OutputDataType = int16_t;
+
+ using namespace wrapper;
+ const int32x4_t mean_vec = vdup_n(mean, wrapper::traits::vector_128_tag{});
+
+ int32_t x = _window_start_x;
+ for(; x <= _window_end_x && _window_step_x <= (_window_end_x - x); x += _window_step_x)
+ {
+ const int16x8_t val = vloadq(input_ptr + x);
+ int32x4x2_t shifted;
+ shifted.val[0] = vsub(vshlq_n_s32(vmovl(vgetlow(val)), 10), mean_vec);
+ shifted.val[1] = vsub(vshlq_n_s32(vmovl(vgethigh(val)), 10), mean_vec);
+
+ int32x4x2_t rescaled = multiply_by_quantized_multiplier_2row(shifted, inv_std_mul, inv_std_shift);
+
+ const int16x8_t weight_val = vloadq(weight_ptr + x);
+ const int32x4_t weight_low = vmovl(vgetlow(weight_val));
+ const int32x4_t weight_high = vmovl(vgethigh(weight_val));
+
+ const int32x4_t bias_low = vloadq(bias_ptr + x);
+ const int32x4_t bias_high = vloadq(bias_ptr + 4 + x);
+
+ int64x2x2_t result_0 = mul_add(rescaled.val[0], weight_low, bias_low);
+ int64x2x2_t result_1 = mul_add(rescaled.val[1], weight_high, bias_high);
+
+ int32x4x2_t combined;
+ combined.val[0] = vcombine(vmovn(vrshrq_n_s64(result_0.val[0], 10)), vmovn(vrshrq_n_s64(result_0.val[1], 10)));
+ combined.val[1] = vcombine(vmovn(vrshrq_n_s64(result_1.val[0], 10)), vmovn(vrshrq_n_s64(result_1.val[1], 10)));
+
+ int32x4x2_t out_val = multiply_by_quantized_multiplier_2row(combined, _output_multiplier, _output_shift + 12);
+
+ vstore(output_ptr + x, vqmovn(out_val.val[0]));
+ vstore(output_ptr + x + 4, vqmovn(out_val.val[1]));
+ }
+
+ for(; x < _window_end_x; ++x)
+ {
+ const auto val = static_cast<int32_t>(input_ptr[x]);
+ const int32_t shifted = (val << 10) - mean;
+ const int32_t rescaled = quantization::multiply_by_quantized_multiplier(shifted, inv_std_mul, inv_std_shift);
+ const int64_t weighted = rescaled * weight_ptr[x] + bias_ptr[x];
+ const auto reverse_shifted = static_cast<int32_t>((weighted + 512) >> 10);
+ int32_t out_val = quantization::multiply_by_quantized_multiplier(reverse_shifted, _output_multiplier, _output_shift + 12);
+ out_val = utility::clamp<decltype(out_val), OutputDataType>(out_val, std::numeric_limits<OutputDataType>::min());
+ output_ptr[x] = static_cast<OutputDataType>(out_val);
+ }
+}
+
+void NEQLSTMLayerNormalizationKernel::compute_qsymm16()
+{
+ using InputDataType = int16_t;
+ using OutputDataType = int16_t;
+ using BiasDataType = int32_t;
+ using AccType = int64_t;
+
+ Iterator input_iterator{ _input, _inout_window };
+ Iterator output_iterator{ _output, _inout_window };
+ Iterator weight_iterator{ _weight, _weight_window };
+ Iterator bias_iterator{ _bias, _weight_window };
+
+ const auto weight_ptr = reinterpret_cast<const InputDataType *>(weight_iterator.ptr());
+ const auto bias_ptr = reinterpret_cast<const BiasDataType *>(bias_iterator.ptr());
+
+ const uint32_t column_size = _input->info()->tensor_shape()[0];
+
+ execute_window_loop(_inout_window, [ &, this](const Coordinates &)
+ {
+ const auto in_ptr = reinterpret_cast<const InputDataType *>(input_iterator.ptr());
+ auto out_ptr = reinterpret_cast<OutputDataType *>(output_iterator.ptr());
+
+ AccType sum{ 0 };
+ AccType sum_sq{ 0 };
+ std::tie(sum, sum_sq) = sum_qsymm16(in_ptr);
+
+ AccType mean{ 0 };
+ AccType variance{ 0 };
+ std::tie(mean, variance) = compute_mean_variance(sum, sum_sq, column_size);
+
+ int32_t stddev_invsqrt_mul{};
+ int32_t stddev_invsqrt_shift{};
+ quantization::get_invsqrt_quantized_multiplier_exp(static_cast<int32_t>(variance), -1, stddev_invsqrt_mul, stddev_invsqrt_shift);
+
+ normalize_qasymm16(in_ptr, out_ptr, weight_ptr, bias_ptr, mean, stddev_invsqrt_mul, stddev_invsqrt_shift);
+ },
+ input_iterator, output_iterator);
+}
+} // namespace arm_compute
\ No newline at end of file
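
Summarising the per-row arithmetic implemented above (an informal transcription of sum_qsymm16, compute_mean_variance and normalize_qasymm16, matching the scalar reference kernel further down in this patch; $\sigma^{-1}$ denotes the quantized inverse-sqrt multiply and $\mathrm{requant}$ the final multiply_by_quantized_multiplier with the output multiplier $M_{\mathrm{out}}$ and shift $s_{\mathrm{out}}$):

$\mu = \frac{1024 \sum_i x_i}{N}, \qquad \sigma^2 = \frac{\big(\sum_i x_i^2\big)\,\lfloor 2^{20}/N \rfloor - \mu^2}{2^{20}} \approx \mathrm{E}[x^2] - \mathrm{E}[x]^2$

$y_i = \mathrm{clamp}_{\mathrm{int16}}\Big(\mathrm{requant}\big(\big\lfloor \tfrac{(1024\,x_i - \mu)\,\sigma^{-1}\,w_i + b_i + 512}{2^{10}} \big\rfloor,\ M_{\mathrm{out}},\ s_{\mathrm{out}} + 12\big)\Big)$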
diff --git a/src/core/utils/quantization/AsymmHelpers.cpp b/src/core/utils/quantization/AsymmHelpers.cpp
index c5eef9dd77..f923518ca4 100644
--- a/src/core/utils/quantization/AsymmHelpers.cpp
+++ b/src/core/utils/quantization/AsymmHelpers.cpp
@@ -202,9 +202,10 @@ int32_t saturating_rounding_doubling_highmul(int32_t a, int32_t b)
bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
int64_t a_64(a);
int64_t b_64(b);
- int64_t ab_64 = a_64 * b_64;
- int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
- int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31));
+ int64_t ab_64 = a_64 * b_64;
+ bool is_positive_or_zero = a == 0 || b == 0 || (std::signbit(a) == std::signbit(b));
+ int32_t nudge = is_positive_or_zero ? (1 << 30) : (1 - (1 << 30));
+ int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31));
return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32;
}
@@ -215,7 +216,7 @@ inline int32_t rounding_divide_by_pow2(int32_t x, int exponent)
return (x >> exponent) + ((x & mask) > threshold ? 1 : 0);
}
-int32_t multiply_by_quantized_multipler(int32_t input, int32_t qmul, int32_t shift)
+int32_t multiply_by_quantized_multiplier(int32_t input, int32_t qmul, int32_t shift)
{
const auto left_shift = shift > 0 ? shift : 0;
const auto right_shift = shift > 0 ? 0 : -shift;
@@ -247,5 +248,80 @@ int32_t saturating_rounding_multiply_by_pow2(int32_t exponent, int32_t v)
return result;
}
}
+
+void get_invsqrt_quantized_multiplier_exp(int32_t input, int32_t reverse_shift, int32_t &output_inv_sqrt, int32_t &output_shift)
+{
+ ARM_COMPUTE_ERROR_ON(input < 0);
+
+ if(input <= 1)
+ {
+ // deal with the inputs 0 and 1 separately to avoid overflow
+ output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
+ output_shift = 0;
+ return;
+ }
+
+ // prepare input for fixed point operation and compute shift value
+ output_shift = 11;
+ while(input >= (1 << 29))
+ {
+ input /= 4;
+ ++output_shift;
+ }
+
+ const uint32_t max_left_shift_bits = __builtin_clz(static_cast<uint32_t>(input)) - 1;
+ const uint32_t max_left_shift_bits_pairs = max_left_shift_bits / 2;
+ const uint32_t left_shift_bit_pairs = max_left_shift_bits_pairs - 1;
+ output_shift -= left_shift_bit_pairs;
+ input <<= 2 * left_shift_bit_pairs;
+
+ // Calculation in fixed point domain with 3 integer bits.
+ using FixedPointRawType = int32_t;
+ constexpr uint32_t fixedpoint_position = 3;
+ constexpr uint32_t fixedpoint_int_position = sizeof(FixedPointRawType) * 8 - 1 - fixedpoint_position;
+ using FixedPoint3 = FixedPointRawType;
+ using FixedPoint0 = FixedPointRawType;
+
+ // fixed-point representations of input/2 and of 1.5, used in the Newton-Raphson iteration
+ const FixedPoint3 fixedpoint_input = (input >> 1);
+ const FixedPoint3 fixedpoint_half_input = rounding_divide_by_pow2(fixedpoint_input, 1);
+ const FixedPoint3 fixedpoint_half_three = (0x1 << fixedpoint_int_position) + (0x1 << (fixedpoint_int_position - 1));
+
+ // initial guess (1) in fixed point representation
+ FixedPoint3 x = 0x1 << fixedpoint_int_position;
+
+ // multiplication of two fixed point numbers, defined for readability
+ auto fixed_point_mul = [](FixedPointRawType a, FixedPointRawType b) -> FixedPointRawType
+ {
+ return saturating_rounding_doubling_highmul(a, b);
+ };
+
+ // rescaling of fixed point to have dst_bit integer bits, defined for readability
+ auto fixed_point_rescale = [](FixedPointRawType a, uint32_t src_bit, uint32_t dst_bit) -> FixedPointRawType
+ {
+ const uint32_t exponent = src_bit - dst_bit;
+ return saturating_rounding_multiply_by_pow2(exponent, a);
+ };
+
+ // 5 iterations of the Newton-Raphson method for the inverse square root: x_{n+1} = 1.5 * x_n - (input / 2) * (x_n)^3
+ constexpr int32_t num_iteration = 5;
+ for(int32_t i = 0; i < num_iteration; ++i)
+ {
+ const auto x3 = fixed_point_rescale(fixed_point_mul(fixed_point_mul(x, x), x), 9, fixedpoint_position);
+ x = fixed_point_rescale(fixed_point_mul(fixedpoint_half_three, x) - fixed_point_mul(fixedpoint_half_input, x3), 6, fixedpoint_position);
+ }
+
+ // fixed point representation of sqrt(1/2)
+ const FixedPoint0 fixedpoint_half_sqrt_2 = 1518500250;
+ x = fixed_point_mul(fixedpoint_half_sqrt_2, x);
+ output_inv_sqrt = x;
+ if(output_shift < 0)
+ {
+ output_inv_sqrt <<= -output_shift;
+ output_shift = 0;
+ }
+ // convert right shift to left shift
+ output_shift *= reverse_shift;
+}
} // quantization
} // arm_compute
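
To make the requantization step concrete, a small hedged sanity sketch (not part of the patch): a multiplier/shift pair derived from a float scale, applied through multiply_by_quantized_multiplier, should approximate plain multiplication by that scale. The sign flip on the shift mirrors what NEQLSTMLayerNormalizationKernel::configure() does above; the scale 1/8192 just matches the weight quantization used by the new tests.

#include "arm_compute/core/utils/quantization/AsymmHelpers.h"

#include <cstdint>
#include <cstdio>

int main()
{
    using namespace arm_compute::quantization;

    const float scale = 1.f / 8192.f;
    int32_t     qmul{};
    int32_t     shift{};
    calculate_quantized_multiplier(scale, &qmul, &shift);
    shift *= -1; // negative now means "shift right", as multiply_by_quantized_multiplier expects

    for(int32_t v : { 1 << 13, 100000, -123456, 1 << 20 })
    {
        // Fixed-point requantization vs. the floating-point reference v * scale
        const int32_t fixed = multiply_by_quantized_multiplier(v, qmul, shift);
        std::printf("v=%d fixed=%d float=%f\n", static_cast<int>(v), static_cast<int>(fixed), static_cast<double>(v) * scale);
    }
    return 0;
}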
diff --git a/tests/validation/NEON/QLSTMLayerNormalization.cpp b/tests/validation/NEON/QLSTMLayerNormalization.cpp
new file mode 100644
index 0000000000..8508a6e483
--- /dev/null
+++ b/tests/validation/NEON/QLSTMLayerNormalization.cpp
@@ -0,0 +1,220 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/runtime/TensorAllocator.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/PaddingCalculator.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Helpers.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/QLSTMLayerNormalizationFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+constexpr uint32_t vector_size_byte = 16;
+
+using test::datasets::ShapeDataset;
+template <uint32_t num_elements_per_iter, uint32_t num_batches, uint32_t num_iteration>
+class QLSTMLayerNormShapeDataSet : public ShapeDataset
+{
+ static constexpr auto boundary_minus_one = num_elements_per_iter * num_iteration - 1;
+ static constexpr auto boundary = num_elements_per_iter * num_iteration;
+ static constexpr auto boundary_plus_one = num_elements_per_iter * num_iteration + 1;
+
+public:
+ QLSTMLayerNormShapeDataSet(std::string name)
+ : ShapeDataset(name,
+ {
+ TensorShape{ boundary_minus_one, num_batches },
+ TensorShape{ boundary, num_batches },
+ TensorShape{ boundary_plus_one, num_batches }
+ })
+ {
+ }
+};
+
+template <uint32_t num_elements_per_iter, uint32_t num_batches>
+class QLSTMLayerNormShapeDataSet<num_elements_per_iter, num_batches, 0> : public ShapeDataset
+{
+public:
+ QLSTMLayerNormShapeDataSet(std::string name)
+ : ShapeDataset(name,
+ {
+ TensorShape{ 1, num_batches },
+ TensorShape{ 2, num_batches }
+ })
+ {
+ }
+};
+} // namespace
+TEST_SUITE(NEON)
+TEST_SUITE(QLSTMLayerNormalization)
+
+static const TensorShape correct_input_shape{ TensorShape(15U, 2U) };
+static const TensorShape correct_weight_shape{ TensorShape(15U) };
+static const TensorShape correct_bias_shape{ TensorShape(15U) };
+static const TensorShape correct_output_shape{ correct_input_shape };
+static const DataType correct_input_dt{ DataType::QSYMM16 };
+static const DataType correct_weight_dt{ DataType::QSYMM16 };
+static const DataType correct_bias_dt{ DataType::S32 };
+static const DataType correct_output_dt{ correct_input_dt };
+static const uint32_t tensor_num_channel{ 1 };
+
+// *INDENT-OFF*
+// clang-format off
+
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL,
+ zip(zip(zip(
+ framework::dataset::make("InputInfo", {
+ TensorInfo(correct_input_shape, tensor_num_channel, DataType::F16), // input supports only QSYMM16
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // weight supports only QSYMM16
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // bias supports only S32
+ TensorInfo(TensorShape(15U, 2U, 2U), tensor_num_channel, correct_input_dt), // input supports only up to 2D
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // weight supports only up to 1D
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // bias supports only up to 1D
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // input_shape[0] != weight_shape[0] should fail
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // weight_shape[0] != bias_shape[0] should fail
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // output shape mismatches with input shape
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // output data type mismatches with input data type
+ }),
+ framework::dataset::make("WeightInfo", {
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, DataType::F16),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(TensorShape(15U, 2U), tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(TensorShape(14U), tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ })
+ ),
+ framework::dataset::make("BiasInfo", {
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, DataType::QSYMM16),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(TensorShape(15U, 2U), tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(TensorShape(14U), tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ })
+ ),
+ framework::dataset::make("OutputInfo", {
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(TensorShape(15, 3), tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, DataType::S32),
+ })
+ ),
+ input_info, weight_info, bias_info, output_info)
+{
+ const Status s = NEQLSTMLayerNormalizationKernel::validate(&input_info, &output_info, &weight_info, &bias_info);
+ ARM_COMPUTE_EXPECT(!bool(s), framework::LogLevel::ERRORS);
+}
+
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using NEQLSTMLayerNormalizationFixture = QLSTMLayerNormalizationValidationFixture<Tensor, Accessor, NEQLSTMLayerNormalizationKernel, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QSYMM16)
+
+/** Tests will be targeting
+ * - Comparison between the NEON kernel and the reference kernel, which implements the exact same algorithm in scalar form
+ * - 1D and 2D input shapes whose first dimension covers boundary values of the 128-bit vector size (0~3 iterations)
+ * - 1D weight and bias shapes with the same size as the first dimension of the input shapes
+ * - Quantization scales both greater than and smaller than one
+ * - Input value ranges are noted in the fixture
+ *
+ * What we can't test
+ * - Since the reference kernel uses exactly the same algorithm in the same quantized domain,
+ * it is hard to fully test whether the algorithm accomplishes what it is supposed to.
+ * - The algorithm is sensitive to the quantization scale, but that sensitivity is hard to fully test
+ * for the aforementioned reason.
+ * - Likewise, corner values are hard to fully test because the reference kernel and the
+ * NEON kernel share the exact same algorithm.
+ */
+
+constexpr uint32_t qsymm16_per_vector = vector_size_byte / sizeof(int16_t);
+
+#define QSYMM16_DATASET_ITER(num_input_batch, num_iter) \
+ combine(combine(zip(zip(QLSTMLayerNormShapeDataSet<qsymm16_per_vector, num_input_batch, num_iter>("InputShape"), \
+ QLSTMLayerNormShapeDataSet<qsymm16_per_vector, 1, num_iter>("WeightShape")), \
+ QLSTMLayerNormShapeDataSet<qsymm16_per_vector, 1, num_iter>("BiasShape")), \
+ framework::dataset::make("DataType", DataType::QSYMM16)), \
+ framework::dataset::make("WeightQuantizationInfo", { QuantizationInfo(1. / 8192), QuantizationInfo(8192) }))
+
+#define QSYMM16_DATASET_1D \
+ concat(concat(QSYMM16_DATASET_ITER(1, 0), QSYMM16_DATASET_ITER(1, 1)), QSYMM16_DATASET_ITER(1, 2))
+
+#define QSYMM16_DATASET_2D \
+ concat(concat(QSYMM16_DATASET_ITER(3, 0), QSYMM16_DATASET_ITER(3, 1)), QSYMM16_DATASET_ITER(3, 2))
+
+FIXTURE_DATA_TEST_CASE(RandomValue1D, NEQLSTMLayerNormalizationFixture<int16_t>, framework::DatasetMode::ALL, QSYMM16_DATASET_1D)
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RandomValue2D, NEQLSTMLayerNormalizationFixture<int16_t>, framework::DatasetMode::ALL, QSYMM16_DATASET_2D)
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+#undef QSYMM16_DATASET_ITER
+#undef QSYMM16_DATASET_2D
+#undef QSYMM16_DATASET_1D
+
+TEST_SUITE_END() // QSYMM16
+TEST_SUITE_END() // Quantized
+TEST_SUITE_END() // QLSTMLayerNormalization
+TEST_SUITE_END() // NEON
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h b/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h
new file mode 100644
index 0000000000..5d2cd2bd55
--- /dev/null
+++ b/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE
+#define ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE
+
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "tests/AssetsLibrary.h"
+#include "tests/Globals.h"
+#include "tests/IAccessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
+#include "tests/validation/reference/QLSTMLayerNormalization.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class QLSTMLayerNormalizationValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, TensorShape weight_shape, TensorShape bias_shape, DataType data_type, QuantizationInfo weight_qinfo)
+ {
+ ARM_COMPUTE_ERROR_ON(data_type != DataType::QSYMM16);
+
+ _data_type = data_type;
+ _qinfo = weight_qinfo;
+
+ _target = compute_target(input_shape, weight_shape, bias_shape);
+ _reference = compute_reference(input_shape, weight_shape, bias_shape);
+ }
+
+protected:
+ template <typename InputType, typename BiasType>
+ void fill(InputType &&input_tensor, InputType &&weight_tensor, BiasType &&bias_tensor)
+ {
+ switch(_data_type)
+ {
+ case DataType::QSYMM16:
+ {
+ // Value ranges are based on the reference implementation's test case.
+ constexpr int16_t input_min = -1000;
+ constexpr int16_t input_max = 1000;
+ constexpr int16_t weight_min = 19000;
+ constexpr int16_t weight_max = 27000;
+ constexpr int32_t bias_min = -16000000;
+ constexpr int32_t bias_max = -13000000;
+
+ std::uniform_int_distribution<> input_distribution(input_min, input_max);
+ std::uniform_int_distribution<> weight_distribution(weight_min, weight_max);
+ std::uniform_int_distribution<> bias_distribution(bias_min, bias_max);
+
+ library->fill(input_tensor, input_distribution, 0);
+ library->fill(weight_tensor, weight_distribution, 0);
+ library->fill(bias_tensor, bias_distribution, 0);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("non-supported data type");
+ break;
+ }
+ }
+
+ void allocate_tensors(const std::vector<TensorType *> &tensors)
+ {
+ for(auto t : tensors)
+ {
+ ARM_COMPUTE_EXPECT(t->info()->is_resizable(), framework::LogLevel::ERRORS);
+ t->allocator()->allocate();
+ ARM_COMPUTE_EXPECT(!t->info()->is_resizable(), framework::LogLevel::ERRORS);
+ }
+ }
+
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
+ {
+ TensorType input = create_tensor<TensorType>(input_shape, _data_type, 1);
+ TensorType weight = create_tensor<TensorType>(weight_shape, _data_type, 1, _qinfo);
+ TensorType bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
+ TensorType output = create_tensor<TensorType>(input_shape, _data_type, 1);
+
+ FunctionType fn;
+ fn.configure(&input, &output, &weight, &bias);
+ allocate_tensors({ &input, &weight, &bias, &output });
+ fill(AccessorType(input), AccessorType(weight), AccessorType(bias));
+
+ ThreadInfo tinfo;
+ tinfo.cpu_info = &NEScheduler::get().cpu_info();
+ fn.run(fn.window(), tinfo);
+
+ return output;
+ }
+
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
+ {
+ // Create reference
+ SimpleTensor<T> input{ input_shape, _data_type, 1 };
+ SimpleTensor<T> weight{ weight_shape, _data_type, 1, _qinfo };
+ SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
+
+ // Fill reference
+ fill(input, weight, bias);
+
+ return reference::qlstm_layer_normalization(input, weight, bias);
+ }
+
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+ DataType _data_type{};
+ QuantizationInfo _qinfo{};
+};
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+
+#endif /* ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE */
diff --git a/tests/validation/reference/QLSTMLayerNormalization.cpp b/tests/validation/reference/QLSTMLayerNormalization.cpp
index 0e24de6584..dd6517f81f 100644
--- a/tests/validation/reference/QLSTMLayerNormalization.cpp
+++ b/tests/validation/reference/QLSTMLayerNormalization.cpp
@@ -26,10 +26,9 @@
#include "ArithmeticOperations.h"
#include "MeanStdDevNormalizationLayer.h"
#include "PixelWiseMultiplication.h"
+#include "arm_compute/core/utils/misc/Utility.h"
#include "src/core/utils/quantization/AsymmHelpers.cpp"
-#include "support/ToolchainSupport.h"
-
namespace arm_compute
{
namespace test
@@ -38,53 +37,60 @@ namespace validation
{
namespace reference
{
-SimpleTensor<float> qlstm_layer_normalization_float_compute(SimpleTensor<float> src, SimpleTensor<float> weight, SimpleTensor<float> bias)
-{
- SimpleTensor<float> output = mean_std_normalization_layer(src);
- output = pixel_wise_multiplication<float, float, float>(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, DataType::F32);
- return arithmetic_operation(ArithmeticOperation::ADD, output, bias, DataType::F32, ConvertPolicy::SATURATE);
-}
-
SimpleTensor<int16_t> qlstm_layer_normalization(const SimpleTensor<int16_t> &src, const SimpleTensor<int16_t> &weight, const SimpleTensor<int32_t> &bias)
{
ARM_COMPUTE_ERROR_ON(src.shape().num_dimensions() > 2);
+ SimpleTensor<int16_t> output{ src.shape(), DataType::QSYMM16 };
- SimpleTensor<float> converted_src{ src.shape(), DataType::F32 };
- SimpleTensor<float> converted_weight{ weight.shape(), DataType::F32 };
- SimpleTensor<float> converted_bias{ bias.shape(), DataType::F32 };
-
- const auto iq_info = src.quantization_info().uniform();
+ const auto wq_info = weight.quantization_info().uniform();
int output_multiplier{};
int output_shift{};
- quantization::calculate_quantized_multiplier(iq_info.scale, &output_multiplier, &output_shift);
-
- const float layer_norm_scale = output_multiplier * std::pow(2, static_cast<double>(output_shift - 31));
- const float bias_scale = std::pow(2., -10) * layer_norm_scale;
+ const auto s = quantization::calculate_quantized_multiplier(wq_info.scale, &output_multiplier, &output_shift);
+ output_shift *= -1;
- for(int i = 0; i < src.num_elements(); i++)
+ if(!bool(s))
{
- converted_src[i] = static_cast<float>(src[i]);
+ output_multiplier = 0;
+ output_shift = 0;
}
- for(int i = 0; i < bias.num_elements(); i++)
- {
- converted_bias[i] = static_cast<float>(bias[i]) * bias_scale;
- }
+ const uint32_t num_batch = src.shape()[1];
+ const uint32_t num_input = src.shape()[0];
- for(int i = 0; i < weight.num_elements(); i++)
+ for(uint32_t batch_idx = 0; batch_idx < num_batch; ++batch_idx)
{
- converted_weight[i] = weight[i] * layer_norm_scale;
- }
+ int64_t sum{};
+ int64_t sum_sq{};
- SimpleTensor<float> output_float = qlstm_layer_normalization_float_compute(converted_src, converted_weight, converted_bias);
- SimpleTensor<int16_t> output{ output_float.shape(), DataType::QSYMM16 };
+ for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx)
+ {
+ const auto index = batch_idx * num_input + input_idx;
+ const auto val = static_cast<int32_t>(src[index]);
+ sum += val;
+ sum_sq += val * val;
+ }
- for(int i = 0; i < output.num_elements(); i++)
- {
- const auto output_val_s32 = static_cast<int32_t>(support::cpp11::round(output_float[i] * std::pow(2, 12)));
- output[i] = utility::clamp<int32_t, int16_t>(output_val_s32, std::numeric_limits<int16_t>::min());
- }
+ const auto temp = static_cast<int64_t>(0x100000) / num_input;
+ const auto mean = sum * 1024 / static_cast<int64_t>(num_input);
+ const auto variance = ((sum_sq * temp) - (mean * mean)) / 0x100000;
+
+ int32_t stddev_invsqrt_mul{};
+ int32_t stddev_invsqrt_shift{};
+ quantization::get_invsqrt_quantized_multiplier_exp(variance, -1, stddev_invsqrt_mul, stddev_invsqrt_shift);
+ for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx)
+ {
+ const auto index = batch_idx * num_input + input_idx;
+ const auto val = static_cast<int32_t>(src[index]);
+ const auto shifted = (val << 10) - mean;
+ const auto rescaled = quantization::multiply_by_quantized_multiplier(shifted, stddev_invsqrt_mul, stddev_invsqrt_shift);
+ const int64_t weighted = rescaled * weight[input_idx] + bias[input_idx];
+ const auto reverse_shifted = static_cast<int32_t>((weighted + 512) >> 10);
+ auto out_val = quantization::multiply_by_quantized_multiplier(reverse_shifted, output_multiplier, output_shift + 12);
+ out_val = arm_compute::utility::clamp<decltype(out_val), int16_t>(out_val, std::numeric_limits<int16_t>::min());
+ output[index] = static_cast<int16_t>(out_val);
+ }
+ }
return output;
}
} // namespace reference