author     Sang-Hoon Park <sang-hoon.park@arm.com>   2020-03-13 14:56:05 +0000
committer  Sang-Hoon Park <sang-hoon.park@arm.com>   2020-04-07 09:00:09 +0000
commit     0d008f77b0085619c446d0ab5dc1228a80776706 (patch)
tree       e1f6e91bf8da63e8ef98e11ab8eb6a6972a284f2 /tests
parent     4df2cf3177129d10500d30056bf8404418f703d6 (diff)
download   ComputeLibrary-0d008f77b0085619c446d0ab5dc1228a80776706.tar.gz
COMPMID-3281: Implement QSYMM16 Layer Normalization for NEON QLSTM
- Reference kernel is modified to use the same algorithm as NEON kernel.
- NEON kernel is implemented.
- Tests for validation and run are added.

Change-Id: I3533bc2bd12c6e9cc75d837ecf193f74ceddf796
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2948
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--  tests/validation/NEON/QLSTMLayerNormalization.cpp            220
-rw-r--r--  tests/validation/fixtures/QLSTMLayerNormalizationFixture.h   143
-rw-r--r--  tests/validation/reference/QLSTMLayerNormalization.cpp        74
3 files changed, 403 insertions, 34 deletions
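
For context, QLSTM layer normalization normalizes each row of the QSYMM16 input by its mean and standard deviation, multiplies by the QSYMM16 weight (gamma) and adds the S32 bias (beta); the NEON kernel and the reference kernel in this patch do all of this in fixed-point arithmetic. The sketch below is a float-domain illustration only, not the library's kernel: the function and parameter names are made up for the example, and the final requantization of the result into the QSYMM16 output (with clamping to the int16_t range) is omitted.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Float-domain sketch of one row of QLSTM layer normalization (illustration only).
std::vector<float> layer_norm_row_sketch(const std::vector<int16_t> &x,      // QSYMM16 input row
                                         const std::vector<int16_t> &weight, // QSYMM16 gamma
                                         const std::vector<int32_t> &bias,   // S32 beta
                                         float                       weight_scale, // weight quantization scale
                                         float                       bias_scale)   // effective bias dequantization scale
{
    const std::size_t n = x.size();

    float mean = 0.f;
    for(int16_t v : x)
    {
        mean += v;
    }
    mean /= static_cast<float>(n);

    float variance = 0.f;
    for(int16_t v : x)
    {
        variance += (v - mean) * (v - mean);
    }
    variance /= static_cast<float>(n);

    const float inv_stddev = 1.f / std::sqrt(variance + 1e-12f);

    std::vector<float> out(n);
    for(std::size_t i = 0; i < n; ++i)
    {
        // Normalize, apply the dequantized weight, and add the dequantized bias.
        // The fixed-point kernels below derive their own multipliers and shifts for
        // these scales and requantize the result into the QSYMM16 output.
        out[i] = (x[i] - mean) * inv_stddev * (weight[i] * weight_scale) + bias[i] * bias_scale;
    }
    return out;
}
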
diff --git a/tests/validation/NEON/QLSTMLayerNormalization.cpp b/tests/validation/NEON/QLSTMLayerNormalization.cpp
new file mode 100644
index 0000000000..8508a6e483
--- /dev/null
+++ b/tests/validation/NEON/QLSTMLayerNormalization.cpp
@@ -0,0 +1,220 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/runtime/TensorAllocator.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/PaddingCalculator.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Helpers.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/QLSTMLayerNormalizationFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+constexpr uint32_t vector_size_byte = 16;
+
+using test::datasets::ShapeDataset;
+template <uint32_t num_elements_per_iter, uint32_t num_batches, uint32_t num_iteration>
+class QLSTMLayerNormShapeDataSet : public ShapeDataset
+{
+ static constexpr auto boundary_minus_one = num_elements_per_iter * num_iteration - 1;
+ static constexpr auto boundary = num_elements_per_iter * num_iteration;
+ static constexpr auto boundary_plus_one = num_elements_per_iter * num_iteration + 1;
+
+public:
+ QLSTMLayerNormShapeDataSet(std::string name)
+ : ShapeDataset(name,
+ {
+ TensorShape{ boundary_minus_one, num_batches },
+ TensorShape{ boundary, num_batches },
+ TensorShape{ boundary_plus_one, num_batches }
+ })
+ {
+ }
+};
+
+template <uint32_t num_elements_per_iter, uint32_t num_batches>
+class QLSTMLayerNormShapeDataSet<num_elements_per_iter, num_batches, 0> : public ShapeDataset
+{
+public:
+ QLSTMLayerNormShapeDataSet(std::string name)
+ : ShapeDataset(name,
+ {
+ TensorShape{ 1, num_batches },
+ TensorShape{ 2, num_batches }
+ })
+ {
+ }
+};
+} // namespace
+TEST_SUITE(NEON)
+TEST_SUITE(QLSTMLayerNormalization)
+
+static const TensorShape correct_input_shape{ TensorShape(15U, 2U) };
+static const TensorShape correct_weight_shape{ TensorShape(15U) };
+static const TensorShape correct_bias_shape{ TensorShape(15U) };
+static const TensorShape correct_output_shape{ correct_input_shape };
+static const DataType correct_input_dt{ DataType::QSYMM16 };
+static const DataType correct_weight_dt{ DataType::QSYMM16 };
+static const DataType correct_bias_dt{ DataType::S32 };
+static const DataType correct_output_dt{ correct_input_dt };
+static const uint32_t tensor_num_channel{ 1 };
+
+// *INDENT-OFF*
+// clang-format off
+
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL,
+ zip(zip(zip(
+ framework::dataset::make("InputInfo", {
+ TensorInfo(correct_input_shape, tensor_num_channel, DataType::F16), // input supports only QSYMM16
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // weight supports only QSYMM16
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // bias supports only S32
+ TensorInfo(TensorShape(15U, 2U, 2U), tensor_num_channel, correct_input_dt), // input supports only up to 2D
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // weight supports only up to 1D
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // bias supports only up to 1D
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // input_shape[0] != weight_shape[0] should fail
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // weight_shape[0] != bias_shape[0] should fail
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // output shape mismatches with input shape
+ TensorInfo(correct_input_shape, tensor_num_channel, correct_input_dt), // output data type mismatches with input data type
+ }),
+ framework::dataset::make("WeightInfo", {
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, DataType::F16),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(TensorShape(15U, 2U), tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(TensorShape(14U), tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ TensorInfo(correct_weight_shape, tensor_num_channel, correct_weight_dt),
+ })
+ ),
+ framework::dataset::make("BiasInfo", {
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, DataType::QSYMM16),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(TensorShape(15U, 2U), tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(TensorShape(14U), tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ TensorInfo(correct_bias_shape, tensor_num_channel, correct_bias_dt),
+ })
+ ),
+ framework::dataset::make("OutputInfo", {
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, correct_output_dt),
+ TensorInfo(TensorShape(15, 3), tensor_num_channel, correct_output_dt),
+ TensorInfo(correct_output_shape, tensor_num_channel, DataType::S32),
+ })
+ ),
+ input_info, weight_info, bias_info, output_info)
+{
+ const Status s = NEQLSTMLayerNormalizationKernel::validate(&input_info, &output_info, &weight_info, &bias_info);
+ ARM_COMPUTE_EXPECT(!bool(s), framework::LogLevel::ERRORS);
+}
+
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using NEQLSTMLayerNormalizationFixture = QLSTMLayerNormalizationValidationFixture<Tensor, Accessor, NEQLSTMLayerNormalizationKernel, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QSYMM16)
+
+/** These tests target:
+ * - Comparison between the NEON kernel and the reference kernel, which implements the exact same algorithm in scalar form
+ * - 1D and 2D input shapes whose first dimension covers the boundary values of the 128-bit vector size (0~3 iterations)
+ * - 1D weight and bias shapes with the same size as the first dimension of the input shapes
+ * - Quantization scales both greater than and smaller than one
+ * - Input value ranges as noted in the fixture
+ *
+ * What we can't test:
+ * - Since the reference kernel uses the exact same algorithm in the same quantized domain,
+ *   it is hard to fully test whether the algorithm accomplishes what it is supposed to.
+ * - The algorithm is sensitive to the quantization scale, but it is hard to fully test
+ *   that sensitivity for the same reason.
+ * - Likewise, it is hard to fully test corner values because the reference kernel and the
+ *   NEON kernel share the exact same algorithm.
+ */
+
+constexpr uint32_t qsymm16_per_vector = vector_size_byte / sizeof(int16_t);
+
+#define QSYMM16_DATASET_ITER(num_input_batch, num_iter) \
+ combine(combine(zip(zip(QLSTMLayerNormShapeDataSet<qsymm16_per_vector, num_input_batch, num_iter>("InputShape"), \
+ QLSTMLayerNormShapeDataSet<qsymm16_per_vector, 1, num_iter>("WeightShape")), \
+ QLSTMLayerNormShapeDataSet<qsymm16_per_vector, 1, num_iter>("BiasShape")), \
+ framework::dataset::make("DataType", DataType::QSYMM16)), \
+ framework::dataset::make("WeightQuantizationInfo", { QuantizationInfo(1. / 8192), QuantizationInfo(8192) }))
+
+#define QSYMM16_DATASET_1D \
+ concat(concat(QSYMM16_DATASET_ITER(1, 0), QSYMM16_DATASET_ITER(1, 1)), QSYMM16_DATASET_ITER(1, 2))
+
+#define QSYMM16_DATASET_2D \
+ concat(concat(QSYMM16_DATASET_ITER(3, 0), QSYMM16_DATASET_ITER(3, 1)), QSYMM16_DATASET_ITER(3, 2))
+
+FIXTURE_DATA_TEST_CASE(RandomValue1D, NEQLSTMLayerNormalizationFixture<int16_t>, framework::DatasetMode::ALL, QSYMM16_DATASET_1D)
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RandomValue2D, NEQLSTMLayerNormalizationFixture<int16_t>, framework::DatasetMode::ALL, QSYMM16_DATASET_2D)
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+#undef QSYMM16_DATASET_ITER
+#undef QSYMM16_DATASET_2D
+#undef QSYMM16_DATASET_1D
+
+TEST_SUITE_END() // QSYMM16
+TEST_SUITE_END() // Quantized
+TEST_SUITE_END() // QLSTMLayerNormalization
+TEST_SUITE_END() // NEON
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h b/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h
new file mode 100644
index 0000000000..5d2cd2bd55
--- /dev/null
+++ b/tests/validation/fixtures/QLSTMLayerNormalizationFixture.h
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE
+#define ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE
+
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "tests/AssetsLibrary.h"
+#include "tests/Globals.h"
+#include "tests/IAccessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
+#include "tests/validation/reference/QLSTMLayerNormalization.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class QLSTMLayerNormalizationValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, TensorShape weight_shape, TensorShape bias_shape, DataType data_type, QuantizationInfo weight_qinfo)
+ {
+ ARM_COMPUTE_ERROR_ON(data_type != DataType::QSYMM16);
+
+ _data_type = data_type;
+ _qinfo = weight_qinfo;
+
+ _target = compute_target(input_shape, weight_shape, bias_shape);
+ _reference = compute_reference(input_shape, weight_shape, bias_shape);
+ }
+
+protected:
+ template <typename InputType, typename BiasType>
+ void fill(InputType &&input_tensor, InputType &&weight_tensor, BiasType &&bias_tensor)
+ {
+ switch(_data_type)
+ {
+ case DataType::QSYMM16:
+ {
+ // Value ranges are based on the reference implementation's test case.
+ constexpr int16_t input_min = -1000;
+ constexpr int16_t input_max = 1000;
+ constexpr int16_t weight_min = 19000;
+ constexpr int16_t weight_max = 27000;
+ constexpr int32_t bias_min = -16000000;
+ constexpr int32_t bias_max = -13000000;
+
+ std::uniform_int_distribution<> input_distribution(input_min, input_max);
+ std::uniform_int_distribution<> weight_distribution(weight_min, weight_max);
+ std::uniform_int_distribution<> bias_distribution(bias_min, bias_max);
+
+ library->fill(input_tensor, input_distribution, 0);
+ library->fill(weight_tensor, weight_distribution, 0);
+ library->fill(bias_tensor, bias_distribution, 0);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("non-supported data type");
+ break;
+ }
+ }
+
+ void allocate_tensors(const std::vector<TensorType *> &tensors)
+ {
+ for(auto t : tensors)
+ {
+ ARM_COMPUTE_EXPECT(t->info()->is_resizable(), framework::LogLevel::ERRORS);
+ t->allocator()->allocate();
+ ARM_COMPUTE_EXPECT(!t->info()->is_resizable(), framework::LogLevel::ERRORS);
+ }
+ }
+
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
+ {
+ TensorType input = create_tensor<TensorType>(input_shape, _data_type, 1);
+ TensorType weight = create_tensor<TensorType>(weight_shape, _data_type, 1, _qinfo);
+ TensorType bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
+ TensorType output = create_tensor<TensorType>(input_shape, _data_type, 1);
+
+ FunctionType fn;
+ fn.configure(&input, &output, &weight, &bias);
+ allocate_tensors({ &input, &weight, &bias, &output });
+ fill(AccessorType(input), AccessorType(weight), AccessorType(bias));
+
+ ThreadInfo tinfo;
+ tinfo.cpu_info = &NEScheduler::get().cpu_info();
+ fn.run(fn.window(), tinfo);
+
+ return output;
+ }
+
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
+ {
+ // Create reference
+ SimpleTensor<T> input{ input_shape, _data_type, 1 };
+ SimpleTensor<T> weight{ weight_shape, _data_type, 1, _qinfo };
+ SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
+
+ // Fill reference
+ fill(input, weight, bias);
+
+ return reference::qlstm_layer_normalization(input, weight, bias);
+ }
+
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+ DataType _data_type{};
+ QuantizationInfo _qinfo{};
+};
+
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+
+#endif /* ARM_COMPUTE_TEST_QLSTM_LAYER_NORMALIZATION_FIXTURE */
diff --git a/tests/validation/reference/QLSTMLayerNormalization.cpp b/tests/validation/reference/QLSTMLayerNormalization.cpp
index 0e24de6584..dd6517f81f 100644
--- a/tests/validation/reference/QLSTMLayerNormalization.cpp
+++ b/tests/validation/reference/QLSTMLayerNormalization.cpp
@@ -26,10 +26,9 @@
#include "ArithmeticOperations.h"
#include "MeanStdDevNormalizationLayer.h"
#include "PixelWiseMultiplication.h"
+#include "arm_compute/core/utils/misc/Utility.h"
#include "src/core/utils/quantization/AsymmHelpers.cpp"
-#include "support/ToolchainSupport.h"
-
namespace arm_compute
{
namespace test
@@ -38,53 +37,60 @@ namespace validation
{
namespace reference
{
-SimpleTensor<float> qlstm_layer_normalization_float_compute(SimpleTensor<float> src, SimpleTensor<float> weight, SimpleTensor<float> bias)
-{
- SimpleTensor<float> output = mean_std_normalization_layer(src);
- output = pixel_wise_multiplication<float, float, float>(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, DataType::F32);
- return arithmetic_operation(ArithmeticOperation::ADD, output, bias, DataType::F32, ConvertPolicy::SATURATE);
-}
-
SimpleTensor<int16_t> qlstm_layer_normalization(const SimpleTensor<int16_t> &src, const SimpleTensor<int16_t> &weight, const SimpleTensor<int32_t> &bias)
{
ARM_COMPUTE_ERROR_ON(src.shape().num_dimensions() > 2);
+ SimpleTensor<int16_t> output{ src.shape(), DataType::QSYMM16 };
- SimpleTensor<float> converted_src{ src.shape(), DataType::F32 };
- SimpleTensor<float> converted_weight{ weight.shape(), DataType::F32 };
- SimpleTensor<float> converted_bias{ bias.shape(), DataType::F32 };
-
- const auto iq_info = src.quantization_info().uniform();
+ const auto wq_info = weight.quantization_info().uniform();
int output_multiplier{};
int output_shift{};
- quantization::calculate_quantized_multiplier(iq_info.scale, &output_multiplier, &output_shift);
-
- const float layer_norm_scale = output_multiplier * std::pow(2, static_cast<double>(output_shift - 31));
- const float bias_scale = std::pow(2., -10) * layer_norm_scale;
+ const auto s = quantization::calculate_quantized_multiplier(wq_info.scale, &output_multiplier, &output_shift);
+ output_shift *= -1;
- for(int i = 0; i < src.num_elements(); i++)
+ if(!bool(s))
{
- converted_src[i] = static_cast<float>(src[i]);
+ output_multiplier = 0;
+ output_shift = 0;
}
- for(int i = 0; i < bias.num_elements(); i++)
- {
- converted_bias[i] = static_cast<float>(bias[i]) * bias_scale;
- }
+ const uint32_t num_batch = src.shape()[1];
+ const uint32_t num_input = src.shape()[0];
- for(int i = 0; i < weight.num_elements(); i++)
+ for(uint32_t batch_idx = 0; batch_idx < num_batch; ++batch_idx)
{
- converted_weight[i] = weight[i] * layer_norm_scale;
- }
+ int64_t sum{};
+ int64_t sum_sq{};
- SimpleTensor<float> output_float = qlstm_layer_normalization_float_compute(converted_src, converted_weight, converted_bias);
- SimpleTensor<int16_t> output{ output_float.shape(), DataType::QSYMM16 };
+ for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx)
+ {
+ const auto index = batch_idx * num_input + input_idx;
+ const auto val = static_cast<int32_t>(src[index]);
+ sum += val;
+ sum_sq += val * val;
+ }
- for(int i = 0; i < output.num_elements(); i++)
- {
- const auto output_val_s32 = static_cast<int32_t>(support::cpp11::round(output_float[i] * std::pow(2, 12)));
- output[i] = utility::clamp<int32_t, int16_t>(output_val_s32, std::numeric_limits<int16_t>::min());
- }
+ const auto temp = static_cast<int64_t>(0x100000) / num_input;
+ const auto mean = sum * 1024 / static_cast<int64_t>(num_input);
+ const auto variance = ((sum_sq * temp) - (mean * mean)) / 0x100000;
+
+ int32_t stddev_invsqrt_mul{};
+ int32_t stddev_invsqrt_shift{};
+ quantization::get_invsqrt_quantized_multiplier_exp(variance, -1, stddev_invsqrt_mul, stddev_invsqrt_shift);
+ for(uint32_t input_idx = 0; input_idx < num_input; ++input_idx)
+ {
+ const auto index = batch_idx * num_input + input_idx;
+ const auto val = static_cast<int32_t>(src[index]);
+ const auto shifted = (val << 10) - mean;
+ const auto rescaled = quantization::multiply_by_quantized_multiplier(shifted, stddev_invsqrt_mul, stddev_invsqrt_shift);
+ const int64_t weighted = rescaled * weight[input_idx] + bias[input_idx];
+ const auto reverse_shifted = static_cast<int32_t>((weighted + 512) >> 10);
+ auto out_val = quantization::multiply_by_quantized_multiplier(reverse_shifted, output_multiplier, output_shift + 12);
+ out_val = arm_compute::utility::clamp<decltype(out_val), int16_t>(out_val, std::numeric_limits<int16_t>::min());
+ output[index] = static_cast<int16_t>(out_val);
+ }
+ }
return output;
}
} // namespace reference
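
A note on the fixed-point statistics in the hunk above: the mean is kept with a 2^10 (Q10) scaling, which is why the per-element term is (val << 10) - mean, and the variance computation uses a 0x100000 (2^20) factor to preserve intermediate precision. The standalone snippet below is an illustration only (plain int64 arithmetic, no library helpers); it checks that these integer statistics agree with the usual float definitions up to truncation error.

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <vector>

int main()
{
    // Example row; values lie within the fixture's input range [-1000, 1000].
    const std::vector<int16_t> row{ -1000, -250, 0, 125, 400, 999 };
    const int64_t n = static_cast<int64_t>(row.size());

    int64_t sum    = 0;
    int64_t sum_sq = 0;
    for(int16_t v : row)
    {
        sum += v;
        sum_sq += static_cast<int64_t>(v) * v;
    }

    // Same scheme as the reference hunk above: mean kept in Q10,
    // 0x100000 (2^20) used to keep precision in the variance.
    const int64_t temp     = static_cast<int64_t>(0x100000) / n;
    const int64_t mean_q10 = sum * 1024 / n;
    const int64_t variance = ((sum_sq * temp) - (mean_q10 * mean_q10)) / 0x100000;

    // Float definitions for comparison: E[x] and E[x^2] - E[x]^2.
    const double mean_f = static_cast<double>(sum) / static_cast<double>(n);
    const double var_f  = static_cast<double>(sum_sq) / static_cast<double>(n) - mean_f * mean_f;

    // Small differences remain because of integer truncation in temp, the mean and the final division.
    assert(std::llabs(mean_q10 - static_cast<int64_t>(mean_f * 1024.0)) <= 1);
    assert(std::llabs(variance - static_cast<int64_t>(var_f)) <= 4);
    return 0;
}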