diff options
Diffstat (limited to 'tests/validation/fixtures/ElementwiseUnaryFixture.h')
-rw-r--r-- | tests/validation/fixtures/ElementwiseUnaryFixture.h | 246 |
1 files changed, 198 insertions, 48 deletions
diff --git a/tests/validation/fixtures/ElementwiseUnaryFixture.h b/tests/validation/fixtures/ElementwiseUnaryFixture.h index 7221226fd1..15344288db 100644 --- a/tests/validation/fixtures/ElementwiseUnaryFixture.h +++ b/tests/validation/fixtures/ElementwiseUnaryFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,8 +24,10 @@ #ifndef ARM_COMPUTE_TEST_ELEMENTWISE_UNARY_FIXTURE #define ARM_COMPUTE_TEST_ELEMENTWISE_UNARY_FIXTURE +#include "arm_compute/core/QuantizationInfo.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/Utils.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" #include "tests/IAccessor.h" @@ -33,6 +35,11 @@ #include "tests/framework/Fixture.h" #include "tests/validation/reference/ElementwiseUnary.h" +#include <tuple> +#include <limits> +#include <type_traits> +#include <vector> + namespace arm_compute { namespace test @@ -43,12 +50,12 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ElementWiseUnaryValidationFixture : public framework::Fixture { public: - template <typename...> - void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, bool use_dynamic_shape = false) + void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, + bool use_dynamic_shape = false, QuantizationInfo qinfo = QuantizationInfo(), QuantizationInfo qinfo_out = QuantizationInfo()) { _op = op; - _target = compute_target(input_shape, input_data_type, in_place); - _reference = compute_reference(input_shape, input_data_type); + _target = compute_target(input_shape, input_data_type, in_place, qinfo, qinfo_out); + _reference = compute_reference(input_shape, input_data_type, qinfo, qinfo_out); _use_dynamic_shape = use_dynamic_shape; } @@ -63,60 +70,131 @@ protected: { case ElementWiseUnary::EXP: { - FloatDistributionType distribution{ FloatType(-1.0f), FloatType(1.0f) }; - library->fill(tensor, distribution, i); + switch(data_type) + { + case DataType::F32: + { + FloatDistributionType distribution{ FloatType(-86.63f), FloatType(88.36f) }; + library->fill(tensor, distribution, i); + break; + } + + case DataType::F16: + { + FloatDistributionType distribution{ FloatType(-9.00f), FloatType(10.73f) }; + library->fill(tensor, distribution, i); + break; + } + + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + library->fill_tensor_uniform(tensor, i); + break; + + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + break; } case ElementWiseUnary::RSQRT: + case ElementWiseUnary::LOG: { - FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) }; - library->fill(tensor, distribution, i); - break; - } - case ElementWiseUnary::ABS: - case ElementWiseUnary::NEG: - { + // For floating-point data type, the chosen input range is all positive numbers + // (i.e. positive and negative zeros are excluded). switch(data_type) { + case DataType::F32: + { + FloatDistributionType distribution{ std::numeric_limits<float>::min(), std::numeric_limits<float>::max() }; + library->fill(tensor, distribution, i); + break; + } + case DataType::F16: { - arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -2.0f, 2.0f }; + FloatDistributionType distribution{ FloatType(0.00006103515625f), FloatType(65504.0f) }; library->fill(tensor, distribution, i); break; } + + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + library->fill_tensor_uniform(tensor, i); + break; + + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + + break; + } + case ElementWiseUnary::SIN: + { + switch(data_type) + { case DataType::F32: + case DataType::F16: { - FloatDistributionType distribution{ FloatType(-2.0f), FloatType(2.0f) }; + FloatDistributionType distribution{ FloatType(-100.0f), FloatType(100.0f) }; library->fill(tensor, distribution, i); break; } + case DataType::S32: { - std::uniform_int_distribution<int32_t> distribution(-100, 100); + std::uniform_int_distribution<int32_t> distribution(std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max()); library->fill(tensor, distribution, i); break; } + + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + library->fill_tensor_uniform(tensor, i); + break; + default: - ARM_COMPUTE_ERROR("DataType for Elementwise Negation Not implemented"); + ARM_COMPUTE_ERROR("Not implemented"); } + break; } - case ElementWiseUnary::LOG: - { - FloatDistributionType distribution{ FloatType(0.0000001f), FloatType(100.0f) }; - library->fill(tensor, distribution, i); - break; - } - case ElementWiseUnary::SIN: - { - FloatDistributionType distribution{ FloatType(-100.00f), FloatType(100.00f) }; - library->fill(tensor, distribution, i); - break; - } + case ElementWiseUnary::ABS: + case ElementWiseUnary::NEG: case ElementWiseUnary::ROUND: { - FloatDistributionType distribution{ FloatType(100.0f), FloatType(-100.0f) }; - library->fill(tensor, distribution, i); + switch(data_type) + { + case DataType::F32: + { + FloatDistributionType distribution{ std::numeric_limits<float>::lowest() / 2, std::numeric_limits<float>::max() / 2 }; + library->fill(tensor, distribution, i); + break; + } + + case DataType::F16: + { + FloatDistributionType distribution{ FloatType(-65504.0f), FloatType(65504.0f) }; + library->fill(tensor, distribution, i); + break; + } + + case DataType::S32: + { + std::uniform_int_distribution<int32_t> distribution(std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max()); + library->fill(tensor, distribution, i); + break; + } + + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + library->fill_tensor_uniform(tensor, i); + break; + + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + break; } default: @@ -124,12 +202,11 @@ protected: } } - TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place) + TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place, QuantizationInfo qinfo, QuantizationInfo qinfo_out) { // Create tensors - TensorType src = create_tensor<TensorType>(shape, data_type); - TensorType dst = create_tensor<TensorType>(shape, data_type); - + TensorType src = create_tensor<TensorType>(shape, data_type, 1, qinfo); + TensorType dst = create_tensor<TensorType>(shape, data_type, 1, qinfo_out); TensorType *actual_dst = in_place ? &src : &dst; // if _use_dynamic_shape is true, this fixture will test scenario for dynamic shapes. @@ -176,28 +253,39 @@ protected: } } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out) { // Create reference - SimpleTensor<T> src{ shape, data_type }; + SimpleTensor<T> src{ shape, data_type, 1, qinfo }; + SimpleTensor<T> dst{ shape, data_type, 1, qinfo_out }; // Fill reference fill(src, 0, data_type); - return reference::elementwise_unary<T>(src, _op); + return reference::elementwise_unary<T>(src, dst, _op); } TensorType _target{}; SimpleTensor<T> _reference{}; ElementWiseUnary _op{}; bool _use_dynamic_shape{ false }; + QuantizationInfo _input_qinfo{}; + QuantizationInfo _output_qinfo{}; +}; +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class RsqrtQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo, QuantizationInfo qinfo_out) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT, false, qinfo, qinfo_out); + } }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class RsqrtValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT); @@ -208,7 +296,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class RsqrtDynamicShapeValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT, true); @@ -219,7 +306,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ExpValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::EXP); @@ -227,10 +313,19 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class ExpQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::EXP, false, iq, oq); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class NegValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::NEG); @@ -238,10 +333,19 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class NegQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::NEG, false, iq, oq); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class NegValidationInPlaceFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type, bool in_place) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, in_place, ElementWiseUnary::NEG); @@ -249,10 +353,19 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class NegQuantizedValidationInPlaceFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, bool in_place, QuantizationInfo iq, QuantizationInfo oq) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, in_place, ElementWiseUnary::NEG, false, iq, oq); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class LogValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::LOG); @@ -260,10 +373,19 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class LogQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::LOG, false, iq, oq); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class AbsValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ABS); @@ -271,10 +393,19 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class AbsQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ABS, false, iq, oq); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class SinValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::SIN); @@ -282,15 +413,34 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class SinQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::SIN, false, iq, oq); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class RoundValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(const TensorShape &shape, DataType data_type) { ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ROUND); } }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class RoundQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ROUND, false, iq, oq); + } +}; } // namespace validation } // namespace test } // namespace arm_compute |