From 0d008f77b0085619c446d0ab5dc1228a80776706 Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Fri, 13 Mar 2020 14:56:05 +0000 Subject: COMPMID-3281: Implement QSYMM16 Layer Normalization for NEON QLSTM - Reference kernel is modified to use the same algorithm as NEON kernel. - NEON kernel is implemented. - Tests for validation and run are added. Change-Id: I3533bc2bd12c6e9cc75d837ecf193f74ceddf796 Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2948 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- tests/validation/NEON/QLSTMLayerNormalization.cpp | 220 +++++++++++++++++++++ .../fixtures/QLSTMLayerNormalizationFixture.h | 143 ++++++++++++++ .../reference/QLSTMLayerNormalization.cpp | 74 +++---- 3 files changed, 403 insertions(+), 34 deletions(-) create mode 100644 tests/validation/NEON/QLSTMLayerNormalization.cpp create mode 100644 tests/validation/fixtures/QLSTMLayerNormalizationFixture.h (limited to 'tests') diff --git a/tests/validation/NEON/QLSTMLayerNormalization.cpp b/tests/validation/NEON/QLSTMLayerNormalization.cpp new file mode 100644 index 0000000000..8508a6e483 --- /dev/null +++ b/tests/validation/NEON/QLSTMLayerNormalization.cpp @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2020 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/runtime/TensorAllocator.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/PaddingCalculator.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Helpers.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/QLSTMLayerNormalizationFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+constexpr uint32_t vector_size_byte = 16;
+
+using test::datasets::ShapeDataset;
+template <uint32_t num_elements_per_iter, uint32_t num_batches, uint32_t num_iteration>
+class QLSTMLayerNormShapeDataSet : public ShapeDataset
+{
+    static constexpr auto boundary_minus_one = num_elements_per_iter * num_iteration - 1;
+    static constexpr auto boundary           = num_elements_per_iter * num_iteration;
+    static constexpr auto boundary_plus_one  = num_elements_per_iter * num_iteration + 1;
+
+public:
+    QLSTMLayerNormShapeDataSet(std::string name)
+        : ShapeDataset(name,
+    {
+        TensorShape{ boundary_minus_one, num_batches },
+        TensorShape{ boundary, num_batches },
+        TensorShape{ boundary_plus_one, num_batches }
+    })
+    {
+    }
+};
+
+template <uint32_t num_elements_per_iter, uint32_t num_batches>
+class QLSTMLayerNormShapeDataSet<num_elements_per_iter, num_batches, 0> : public ShapeDataset
+{
+public:
+    QLSTMLayerNormShapeDataSet(std::string name)
+        : ShapeDataset(name,
+    {
+        TensorShape{ 1, num_batches },
+        TensorShape{ 2, num_batches }
+    })
+    {
+    }
+};
+} // namespace
+TEST_SUITE(NEON)
+TEST_SUITE(QLSTMLayerNormalization)
+
+static const TensorShape correct_input_shape{ TensorShape(15U, 2U) };
+static const TensorShape correct_weight_shape{ TensorShape(15U) };
+static const TensorShape correct_bias_shape{ TensorShape(15U) };
+static const TensorShape correct_output_shape{ correct_input_shape };
+static const DataType correct_input_dt{ DataType::QSYMM16 };
+static const DataType correct_weight_dt{ DataType::QSYMM16 };
+static const DataType correct_bias_dt{ DataType::S32 };
+static const DataType correct_output_dt{ correct_input_dt };
+static const uint32_t tensor_num_channel{ 1 };
+
+// *INDENT-OFF*
+// clang-format off
+
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL,
+    zip(zip(zip(
+    framework::dataset::make("InputInfo", {
+        TensorInfo(correct_input_shape, tensor_num_channel, DataType::F16),          // input supports only QSYMM16
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // weight supports only QSYMM16
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // bias supports only S32
+        TensorInfo(TensorShape(15U, 2U, 2U), tensor_num_channel, correct_input_dt),  // input supports only up to 2D
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // weight supports only up to 1D
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // bias supports only up to 1D
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // input_shape[0] != weight_shape[0] should fail
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // weight_shape[0] != bias_shape[0] should fail
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // output shape mismatches with input shape
+        TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt),       // output data type mismatches with input data type
+    }),
+    framework::dataset::make("WeightInfo", {
+        TensorInfo(correct_weight_shape,
tensor_num_channel, correct_weight_dt),
+        TensorInfo(correct_weight_shape, tensor_num_channel, DataType::F16),
+        TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+        TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+        TensorInfo(TensorShape(15U, 2U), tensor_num_channel, correct_weight_dt),
+        TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+        TensorInfo(TensorShape(14U), tensor_num_channel, correct_weight_dt),
+        TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+        TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+        TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+    })
+    ),
+    framework::dataset::make("BiasInfo", {
+        TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+        TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+        TensorInfo(correct_bias_shape, tensor_num_channel, DataType::QSYMM16),
+        TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+        TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+        TensorInfo(TensorShape(15U, 2U), tensor_num_channel, correct_bias_dt),
+        TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+        TensorInfo(TensorShape(14U), tensor_num_channel, correct_bias_dt),
+        TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+        TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+    })
+    ),
+    framework::dataset::make("OutputInfo", {
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+        TensorInfo(TensorShape(15, 3), tensor_num_channel, correct_output_dt),
+        TensorInfo(correct_output_shape, tensor_num_channel, DataType::S32),
+    })
+    ),
+    input_info, weight_info, bias_info, output_info)
+{
+    const Status s = NEQLSTMLayerNormalizationKernel::validate(&input_info, &output_info, &weight_info, &bias_info);
+    ARM_COMPUTE_EXPECT(!bool(s), framework::LogLevel::ERRORS);
+}
+
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using NEQLSTMLayerNormalizationFixture = QLSTMLayerNormalizationValidationFixture<Tensor, Accessor, NEQLSTMLayerNormalizationKernel, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QSYMM16)
+
+/** These tests target:
+ * - Comparison between the NEON kernel and a scalar reference kernel that uses the exact same algorithm
+ * - 1D and 2D input shapes whose first dimension covers boundary values of the 128-bit vector size (0~3 iterations)
+ * - 1D weight and bias shapes with the same size as the input's first dimension
+ * - Quantization scales both greater than and smaller than one
+ * - Input value ranges as noted in the fixture
+ *
+ * What cannot be fully tested:
+ * - Since the reference kernel uses the exact same algorithm in the same quantized domain,
+ *   it is hard to fully verify that the algorithm accomplishes what it is supposed to.
+ * - The algorithm is sensitive to the quantization scale, but for the same reason this
+ *   sensitivity is hard to test exhaustively.
+ * - Again, it is hard to fully test corner values due to the exact same algorithm of the + * reference kernel and the NEON kernel. + */ + +constexpr uint32_t qsymm16_per_vector = vector_size_byte / sizeof(int16_t); + +#define QSYMM16_DATASET_ITER(num_input_batch, num_iter) \ + combine(combine(zip(zip(QLSTMLayerNormShapeDataSet("InputShape"), \ + QLSTMLayerNormShapeDataSet("WeightShape")), \ + QLSTMLayerNormShapeDataSet("BiasShape")), \ + framework::dataset::make("DataType", DataType::QSYMM16)), \ + framework::dataset::make("WeightQuantizationInfo", { QuantizationInfo(1. / 8192), QuantizationInfo(8192) })) + +#define QSYMM16_DATASET_1D \ + concat(concat(QSYMM16_DATASET_ITER(1, 0), QSYMM16_DATASET_ITER(1, 1)), QSYMM16_DATASET_ITER(1, 2)) + +#define QSYMM16_DATASET_2D \ + concat(concat(QSYMM16_DATASET_ITER(3, 0), QSYMM16_DATASET_ITER(3, 1)), QSYMM16_DATASET_ITER(3, 2)) + +FIXTURE_DATA_TEST_CASE(RandomValue1D, NEQLSTMLayerNormalizationFixture, framework::DatasetMode::ALL, QSYMM16_DATASET_1D) +{ + // Validate output + validate(Accessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(RandomValue2D, NEQLSTMLayerNormalizationFixture, framework::DatasetMode::ALL, QSYMM16_DATASET_2D) +{ + // Validate output + validate(Accessor(_target), _reference); +} + +#undef QSYMM16_DATASET_ITER +#undef QSYMM16_DATASET_2D +#undef QSYMM16_DATASET_1D + +TEST_SUITE_END() // QSYMM16 +TEST_SUITE_END() // Quantized +TEST_SUITE_END() // QLSTMLayerNormalization +TEST_SUITE_END() // NEON + +} // namespace validation +} // namespace test +} // namespace arm_compute diff --git a/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h b/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h new file mode 100644 index 0000000000..5d2cd2bd55 --- /dev/null +++ b/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2020 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE +#define ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE + +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" +#include "tests/AssetsLibrary.h" +#include "tests/Globals.h" +#include "tests/IAccessor.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Fixture.h" +#include "tests/validation/Helpers.h" +#include "tests/validation/reference/QLSTMLayerNormalization.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +template +class QLSTMLayerNormalizationValidationFixture : public framework::Fixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape weight_shape, TensorShape bias_shape, DataType data_type, QuantizationInfo weight_qinfo) + { + ARM_COMPUTE_ERROR_ON(data_type != DataType::QSYMM16); + + _data_type = data_type; + _qinfo = weight_qinfo; + + _target = compute_target(input_shape, weight_shape, bias_shape); + _reference = compute_reference(input_shape, weight_shape, bias_shape); + } + +protected: + template + void fill(InputType &&input_tensor, InputType &&weight_tensor, BiasType &&bias_tensor) + { + switch(_data_type) + { + case DataType::QSYMM16: + { + // Value ranges are based on reference implementation's test case. + constexpr int16_t input_min = -1000; + constexpr int16_t input_max = 1000; + constexpr int16_t weight_min = 19000; + constexpr int16_t weight_max = 27000; + constexpr int32_t bias_min = -16000000; + constexpr int32_t bias_max = -13000000; + + std::uniform_int_distribution<> input_distribution(input_min, input_max); + std::uniform_int_distribution<> weight_distribution(weight_min, weight_max); + std::uniform_int_distribution<> bias_distribution(bias_min, bias_max); + + library->fill(input_tensor, input_distribution, 0); + library->fill(weight_tensor, weight_distribution, 0); + library->fill(bias_tensor, bias_distribution, 0); + break; + } + default: + ARM_COMPUTE_ERROR("non-supported data type"); + break; + } + } + + void allocate_tensors(const std::vector &tensors) + { + for(auto t : tensors) + { + ARM_COMPUTE_EXPECT(t->info()->is_resizable(), framework::LogLevel::ERRORS); + t->allocator()->allocate(); + ARM_COMPUTE_EXPECT(!t->info()->is_resizable(), framework::LogLevel::ERRORS); + } + } + + TensorType compute_target(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape) + { + TensorType input = create_tensor(input_shape, _data_type, 1); + TensorType weight = create_tensor(weight_shape, _data_type, 1, _qinfo); + TensorType bias = create_tensor(bias_shape, DataType::S32, 1); + TensorType output = create_tensor(input_shape, _data_type, 1); + + FunctionType fn; + fn.configure(&input, &output, &weight, &bias); + allocate_tensors({ &input, &weight, &bias, &output }); + fill(AccessorType(input), AccessorType(weight), AccessorType(bias)); + + ThreadInfo tinfo; + tinfo.cpu_info = &NEScheduler::get().cpu_info(); + fn.run(fn.window(), tinfo); + + return output; + } + + SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape) + { + // Create reference + SimpleTensor input{ input_shape, _data_type, 1 }; + SimpleTensor weight{ weight_shape, _data_type, 1, _qinfo }; + SimpleTensor bias{ bias_shape, DataType::S32, 1 }; + + // Fill reference + fill(input, weight, bias); + + return reference::qlstm_layer_normalization(input, weight, bias); + } + + TensorType _target{}; + 
SimpleTensor _reference{}; + DataType _data_type{}; + QuantizationInfo _qinfo{}; +}; + +} // namespace validation +} // namespace test +} // namespace arm_compute + +#endif /* ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE */ diff --git a/tests/validation/reference/QLSTMLayerNormalization.cpp b/tests/validation/reference/QLSTMLayerNormalization.cpp index 0e24de6584..dd6517f81f 100644 --- a/tests/validation/reference/QLSTMLayerNormalization.cpp +++ b/tests/validation/reference/QLSTMLayerNormalization.cpp @@ -26,10 +26,9 @@ #include "ArithmeticOperations.h" #include "MeanStdDevNormalizationLayer.h" #include "PixelWiseMultiplication.h" +#include "arm_compute/core/utils/misc/Utility.h" #include "src/core/utils/quantization/AsymmHelpers.cpp" -#include "support/ToolchainSupport.h" - namespace arm_compute { namespace test @@ -38,53 +37,60 @@ namespace validation { namespace reference { -SimpleTensor qlstm_layer_normalization_float_compute(SimpleTensor src, SimpleTensor weight, SimpleTensor bias) -{ - SimpleTensor output = mean_std_normalization_layer(src); - output = pixel_wise_multiplication(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, DataType::F32); - return arithmetic_operation(ArithmeticOperation::ADD, output, bias, DataType::F32, ConvertPolicy::SATURATE); -} - SimpleTensor qlstm_layer_normalization(const SimpleTensor &src, const SimpleTensor &weight, const SimpleTensor &bias) { ARM_COMPUTE_ERROR_ON(src.shape().num_dimensions() > 2); + SimpleTensor output{ src.shape(), DataType::QSYMM16 }; - SimpleTensor converted_src{ src.shape(), DataType::F32 }; - SimpleTensor converted_weight{ weight.shape(), DataType::F32 }; - SimpleTensor converted_bias{ bias.shape(), DataType::F32 }; - - const auto iq_info = src.quantization_info().uniform(); + const auto wq_info = weight.quantization_info().uniform(); int output_multiplier{}; int output_shift{}; - quantization::calculate_quantized_multiplier(iq_info.scale, &output_multiplier, &output_shift); - - const float layer_norm_scale = output_multiplier * std::pow(2, static_cast(output_shift - 31)); - const float bias_scale = std::pow(2., -10) * layer_norm_scale; + const auto s = quantization::calculate_quantized_multiplier(wq_info.scale, &output_multiplier, &output_shift); + output_shift *= -1; - for(int i = 0; i < src.num_elements(); i++) + if(!bool(s)) { - converted_src[i] = static_cast(src[i]); + output_multiplier = 0; + output_shift = 0; } - for(int i = 0; i < bias.num_elements(); i++) - { - converted_bias[i] = static_cast(bias[i]) * bias_scale; - } + const uint32_t num_batch = src.shape()[1]; + const uint32_t num_input = src.shape()[0]; - for(int i = 0; i < weight.num_elements(); i++) + for(uint32_t batch_idx = 0; batch_idx < num_batch; ++batch_idx) { - converted_weight[i] = weight[i] * layer_norm_scale; - } + int64_t sum{}; + int64_t sum_sq{}; - SimpleTensor output_float = qlstm_layer_normalization_float_compute(converted_src, converted_weight, converted_bias); - SimpleTensor output{ output_float.shape(), DataType::QSYMM16 }; + for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx) + { + const auto index = batch_idx * num_input + input_idx; + const auto val = static_cast(src[index]); + sum += val; + sum_sq += val * val; + } - for(int i = 0; i < output.num_elements(); i++) - { - const auto output_val_s32 = static_cast(support::cpp11::round(output_float[i] * std::pow(2, 12))); - output[i] = utility::clamp(output_val_s32, std::numeric_limits::min()); - } + const auto temp = static_cast(0x100000) / num_input; + const 
auto mean     = sum * 1024 / static_cast<int32_t>(num_input);
+        const auto variance = ((sum_sq * temp) - (mean * mean)) / 0x100000;
+
+        int32_t stddev_invsqrt_mul{};
+        int32_t stddev_invsqrt_shift{};
+        quantization::get_invsqrt_quantized_multiplier_exp(variance, -1, stddev_invsqrt_mul, stddev_invsqrt_shift);
+        for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx)
+        {
+            const auto    index           = batch_idx * num_input + input_idx;
+            const auto    val             = static_cast<int32_t>(src[index]);
+            const auto    shifted         = (val << 10) - mean;
+            const auto    rescaled        = quantization::multiply_by_quantized_multiplier(shifted, stddev_invsqrt_mul, stddev_invsqrt_shift);
+            const int64_t weighted        = rescaled * weight[input_idx] + bias[input_idx];
+            const auto    reverse_shifted = static_cast<int32_t>((weighted + 512) >> 10);
+            auto          out_val         = quantization::multiply_by_quantized_multiplier(reverse_shifted, output_multiplier, output_shift + 12);
+            out_val                       = arm_compute::utility::clamp<int32_t, int16_t>(out_val, std::numeric_limits<int16_t>::min());
+            output[index]                 = static_cast<int16_t>(out_val);
+        }
+    }
     return output;
 }
 } // namespace reference
-- cgit v1.2.1
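
A note on the arithmetic in the new reference loop above: everything stays in integer codes. The per-batch mean is kept in Q10 (sum * 1024 / N), the variance term is computed through a 2^20 scale factor and brought back down, each input is shifted into the zero-mean Q10 domain, divided by the standard deviation via a quantized inverse-square-root multiplier, scaled by the raw weight code plus bias, shifted back down by 10 bits, and finally requantized with the weight multiplier plus a 2^12 factor before clamping to the QSYMM16 range. The standalone sketch below mirrors those scaling steps on plain std::vector data; the function name is invented for illustration, and the quantized inverse-square-root and requantization helpers are replaced by floating-point stand-ins, so it approximates the data flow rather than reproducing the kernel bit-exactly.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical illustration only (not part of the patch or the library API).
// x, weight: QSYMM16 codes; bias: S32 codes; weight_scale: the weight tensor's quantization scale.
std::vector<int16_t> layer_norm_qsymm16_sketch(const std::vector<int16_t> &x,
                                               const std::vector<int16_t> &weight,
                                               const std::vector<int32_t> &bias,
                                               float weight_scale)
{
    const int64_t n = static_cast<int64_t>(x.size());
    int64_t       sum{ 0 };
    int64_t       sum_sq{ 0 };
    for(int16_t v : x)
    {
        sum += v;
        sum_sq += static_cast<int64_t>(v) * v;
    }

    // Mean in Q10; variance computed through a 2^20 scale and reduced back, as in the reference.
    const int64_t mean     = sum * 1024 / n;
    const int64_t variance = (sum_sq * (0x100000 / n) - mean * mean) / 0x100000;
    const double  inv_std  = 1.0 / std::sqrt(static_cast<double>(std::max<int64_t>(variance, 1)));

    std::vector<int16_t> out(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        const int64_t shifted  = (static_cast<int64_t>(x[i]) << 10) - mean;  // zero-mean value in Q10
        const double  weighted = shifted * inv_std * weight[i] + bias[i];    // weight and bias applied as raw codes
        const double  descaled = (weighted + 512.0) / 1024.0;                // the "(x + 512) >> 10" step
        const double  out_f    = descaled * weight_scale * 4096.0;           // requantize: weight scale times 2^12
        const long    clamped  = std::max(-32768L, std::min(32767L, std::lround(out_f)));
        out[i]                 = static_cast<int16_t>(clamped);
    }
    return out;
}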