diff options
author | Murray Kornelsen <murray.kornelsen@mail.mcgill.ca> | 2022-07-13 21:40:26 -0400 |
---|---|---|
committer | Pablo Marquez Tello <pablo.tello@arm.com> | 2022-09-14 06:48:39 +0000 |
commit | 6e09e1404c635d948cf20eb6b4b5747dfb6656f2 (patch) | |
tree | 006199bd21b8a1330e1f1c86be60084bfb466706 /tests/validation/fixtures | |
parent | a4814e8394ffdd7e268614d54cc22e30648f48ff (diff) | |
download | ComputeLibrary-6e09e1404c635d948cf20eb6b4b5747dfb6656f2.tar.gz |
INT8 Quantized MeanStdDevNorm (LayerNorm)
Implements LayerNorm for qasymm8 tensors.
Uses uint8x16 loads and stores.
Summation is performed in integer arithmetic (vpaddl)
Normalization is performed in float32 before requantizing back to int8.
Signed-off-by: Murray Kornelsen <murray.kornelsen@mail.mcgill.ca>
Change-Id: I2407c8b34717fb47adab98791bd76fb8a3c62f4a
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7922
Comments-Addressed: Pablo Marquez Tello <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h | 39 |
1 files changed, 23 insertions, 16 deletions
diff --git a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h index 9868cd1abf..f3c108e6da 100644 --- a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h +++ b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -45,29 +45,35 @@ class MeanStdDevNormalizationLayerValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8f) + void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8) { - _data_type = dt; - _target = compute_target(shape, dt, in_place, epsilon); - _reference = compute_reference(shape, dt, epsilon); + QuantizationInfo qi = QuantizationInfo(0.5f, 10); + _data_type = dt; + _target = compute_target(shape, dt, in_place, epsilon, qi); + _reference = compute_reference(shape, dt, epsilon, qi); } protected: template <typename U> - void fill(U &&src_tensor) + void fill(U &&tensor) { - static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); - using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; - - DistributionType distribution{ T(-1.0f), T(1.0f) }; - library->fill(src_tensor, distribution, 0); + if(is_data_type_float(_data_type)) + { + std::uniform_real_distribution<> distribution{ -1.0f, 1.0f }; + library->fill(tensor, distribution, 0); + } + else + { + std::uniform_int_distribution<> distribution{ 0, 255 }; + library->fill(tensor, distribution, 0); + } } - TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon) + TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon, QuantizationInfo qi) { // Create tensors - TensorType src = create_tensor<TensorType>(shape, dt, 1); - TensorType dst; + TensorType src = create_tensor<TensorType>(shape, dt, 1, qi); + TensorType dst = create_tensor<TensorType>(shape, dt, 1, qi); TensorType *dst_ptr = in_place ? &src : &dst; @@ -104,10 +110,10 @@ protected: } } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon, QuantizationInfo qi) { // Create reference - SimpleTensor<T> ref_src{ shape, dt, 1 }; + SimpleTensor<T> ref_src{ shape, dt, 1, qi }; // Fill reference fill(ref_src); @@ -119,6 +125,7 @@ protected: SimpleTensor<T> _reference{}; DataType _data_type{}; }; + } // namespace validation } // namespace test } // namespace arm_compute |