diff options
Diffstat (limited to 'tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h | 47 |
1 files changed, 27 insertions, 20 deletions
diff --git a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h index 1c48b74baf..bf5d20790c 100644 --- a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h +++ b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -44,29 +44,35 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class MeanStdDevNormalizationLayerValidationFixture : public framework::Fixture { public: - template <typename...> - void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8f) + void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8) { - _data_type = dt; - _target = compute_target(shape, dt, in_place, epsilon); - _reference = compute_reference(shape, dt, epsilon); + QuantizationInfo qi = QuantizationInfo(0.5f, 10); + _data_type = dt; + _target = compute_target(shape, dt, in_place, epsilon, qi); + _reference = compute_reference(shape, dt, epsilon, qi); } protected: template <typename U> - void fill(U &&src_tensor) + void fill(U &&tensor) { - const float min_bound = -1.f; - const float max_bound = 1.f; - std::uniform_real_distribution<> distribution(min_bound, max_bound); - library->fill(src_tensor, distribution, 0); + if(is_data_type_float(_data_type)) + { + std::uniform_real_distribution<> distribution{ -1.0f, 1.0f }; + library->fill(tensor, distribution, 0); + } + else + { + std::uniform_int_distribution<> distribution{ 0, 255 }; + library->fill(tensor, distribution, 0); + } } - TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon) + TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon, QuantizationInfo qi) { // Create tensors - TensorType src = create_tensor<TensorType>(shape, dt, 1); - TensorType dst; + TensorType src = create_tensor<TensorType>(shape, dt, 1, qi); + TensorType dst = create_tensor<TensorType>(shape, dt, 1, qi); TensorType *dst_ptr = in_place ? &src : &dst; @@ -74,17 +80,17 @@ protected: FunctionType norm; norm.configure(&src, dst_ptr, epsilon); - ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); // Allocate tensors src.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); if(!in_place) { dst.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); } // Fill tensors @@ -103,10 +109,10 @@ protected: } } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon, QuantizationInfo qi) { // Create reference - SimpleTensor<T> ref_src{ shape, dt, 1 }; + SimpleTensor<T> ref_src{ shape, dt, 1, qi }; // Fill reference fill(ref_src); @@ -118,6 +124,7 @@ protected: SimpleTensor<T> _reference{}; DataType _data_type{}; }; + } // namespace validation } // namespace test } // namespace arm_compute |