diff options
Diffstat (limited to 'tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h | 39 |
1 files changed, 23 insertions, 16 deletions
diff --git a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h index 9868cd1abf..f3c108e6da 100644 --- a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h +++ b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -45,29 +45,35 @@ class MeanStdDevNormalizationLayerValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8f) + void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8) { - _data_type = dt; - _target = compute_target(shape, dt, in_place, epsilon); - _reference = compute_reference(shape, dt, epsilon); + QuantizationInfo qi = QuantizationInfo(0.5f, 10); + _data_type = dt; + _target = compute_target(shape, dt, in_place, epsilon, qi); + _reference = compute_reference(shape, dt, epsilon, qi); } protected: template <typename U> - void fill(U &&src_tensor) + void fill(U &&tensor) { - static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); - using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; - - DistributionType distribution{ T(-1.0f), T(1.0f) }; - library->fill(src_tensor, distribution, 0); + if(is_data_type_float(_data_type)) + { + std::uniform_real_distribution<> distribution{ -1.0f, 1.0f }; + library->fill(tensor, distribution, 0); + } + else + { + std::uniform_int_distribution<> distribution{ 0, 255 }; + library->fill(tensor, distribution, 0); + } } - TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon) + TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon, QuantizationInfo qi) { // Create tensors - TensorType src = create_tensor<TensorType>(shape, dt, 1); - TensorType dst; + TensorType src = create_tensor<TensorType>(shape, dt, 1, qi); + TensorType dst = create_tensor<TensorType>(shape, dt, 1, qi); TensorType *dst_ptr = in_place ? &src : &dst; @@ -104,10 +110,10 @@ protected: } } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon, QuantizationInfo qi) { // Create reference - SimpleTensor<T> ref_src{ shape, dt, 1 }; + SimpleTensor<T> ref_src{ shape, dt, 1, qi }; // Fill reference fill(ref_src); @@ -119,6 +125,7 @@ protected: SimpleTensor<T> _reference{}; DataType _data_type{}; }; + } // namespace validation } // namespace test } // namespace arm_compute |