aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ElementwiseUnaryFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ElementwiseUnaryFixture.h')
-rw-r--r--tests/validation/fixtures/ElementwiseUnaryFixture.h212
1 files changed, 181 insertions, 31 deletions
diff --git a/tests/validation/fixtures/ElementwiseUnaryFixture.h b/tests/validation/fixtures/ElementwiseUnaryFixture.h
index 1dc4f03e99..9b40d34d2b 100644
--- a/tests/validation/fixtures/ElementwiseUnaryFixture.h
+++ b/tests/validation/fixtures/ElementwiseUnaryFixture.h
@@ -24,8 +24,10 @@
#ifndef ARM_COMPUTE_TEST_ELEMENTWISE_UNARY_FIXTURE
#define ARM_COMPUTE_TEST_ELEMENTWISE_UNARY_FIXTURE
+#include "arm_compute/core/QuantizationInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
#include "tests/IAccessor.h"
@@ -33,6 +35,11 @@
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/ElementwiseUnary.h"
+#include <tuple>
+#include <limits>
+#include <type_traits>
+#include <vector>
+
namespace arm_compute
{
namespace test
@@ -64,67 +71,131 @@ protected:
{
case ElementWiseUnary::EXP:
{
- FloatDistributionType distribution{ FloatType(-1.0f), FloatType(1.0f) };
- library->fill(tensor, distribution, i);
+ switch(data_type)
+ {
+ case DataType::F32:
+ {
+ FloatDistributionType distribution{ FloatType(-86.63f), FloatType(88.36f) };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+
+ case DataType::F16:
+ {
+ FloatDistributionType distribution{ FloatType(-9.00f), FloatType(10.73f) };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ library->fill_tensor_uniform(tensor, i);
+ break;
+
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+
break;
}
case ElementWiseUnary::RSQRT:
+ case ElementWiseUnary::LOG:
{
- if(data_type == DataType::F32 || data_type == DataType::F16)
+ // For floating-point data type, the chosen input range is all positive numbers
+ // (i.e. positive and negative zeros are excluded).
+ switch(data_type)
{
- FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
- library->fill(tensor, distribution, i);
+ case DataType::F32:
+ {
+ FloatDistributionType distribution{ std::numeric_limits<float>::min(), std::numeric_limits<float>::max() };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+
+ case DataType::F16:
+ {
+ FloatDistributionType distribution{ FloatType(0.00006103515625f), FloatType(65504.0f) };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ library->fill_tensor_uniform(tensor, i);
+ break;
+
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
}
- else
+
+ break;
+ }
+ case ElementWiseUnary::SIN:
+ {
+ switch(data_type)
{
- library->fill_tensor_uniform(tensor, i);
+ case DataType::F32:
+ case DataType::F16:
+ {
+ FloatDistributionType distribution{ FloatType(-100.0f), FloatType(100.0f) };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+
+ case DataType::S32:
+ {
+ std::uniform_int_distribution<int32_t> distribution(std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max());
+ library->fill(tensor, distribution, i);
+ break;
+ }
+
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ library->fill_tensor_uniform(tensor, i);
+ break;
+
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
}
+
break;
}
case ElementWiseUnary::ABS:
case ElementWiseUnary::NEG:
+ case ElementWiseUnary::ROUND:
{
switch(data_type)
{
- case DataType::F16:
+ case DataType::F32:
{
- arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -2.0f, 2.0f };
+ FloatDistributionType distribution{ std::numeric_limits<float>::lowest() / 2, std::numeric_limits<float>::max() / 2 };
library->fill(tensor, distribution, i);
break;
}
- case DataType::F32:
+
+ case DataType::F16:
{
- FloatDistributionType distribution{ FloatType(-2.0f), FloatType(2.0f) };
+ FloatDistributionType distribution{ FloatType(-65504.0f), FloatType(65504.0f) };
library->fill(tensor, distribution, i);
break;
}
+
case DataType::S32:
{
- std::uniform_int_distribution<int32_t> distribution(-100, 100);
+ std::uniform_int_distribution<int32_t> distribution(std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max());
library->fill(tensor, distribution, i);
break;
}
+
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ library->fill_tensor_uniform(tensor, i);
+ break;
+
default:
- ARM_COMPUTE_ERROR("DataType for Elementwise Negation Not implemented");
+ ARM_COMPUTE_ERROR("Not implemented");
}
- break;
- }
- case ElementWiseUnary::LOG:
- {
- FloatDistributionType distribution{ FloatType(0.0000001f), FloatType(100.0f) };
- library->fill(tensor, distribution, i);
- break;
- }
- case ElementWiseUnary::SIN:
- {
- FloatDistributionType distribution{ FloatType(-100.00f), FloatType(100.00f) };
- library->fill(tensor, distribution, i);
- break;
- }
- case ElementWiseUnary::ROUND:
- {
- FloatDistributionType distribution{ FloatType(100.0f), FloatType(-100.0f) };
- library->fill(tensor, distribution, i);
+
break;
}
default:
@@ -199,6 +270,8 @@ protected:
SimpleTensor<T> _reference{};
ElementWiseUnary _op{};
bool _use_dynamic_shape{ false };
+ QuantizationInfo _input_qinfo{};
+ QuantizationInfo _output_qinfo{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class RsqrtQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
@@ -245,6 +318,17 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ExpQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::EXP, false, iq, oq);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class NegValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
@@ -256,6 +340,17 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class NegQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::NEG, false, iq, oq);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class NegValidationInPlaceFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
@@ -267,6 +362,17 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class NegQuantizedValidationInPlaceFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, bool in_place, QuantizationInfo iq, QuantizationInfo oq)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, in_place, ElementWiseUnary::NEG, false, iq, oq);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class LogValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
@@ -278,6 +384,17 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class LogQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::LOG, false, iq, oq);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class AbsValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
@@ -289,6 +406,17 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class AbsQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ABS, false, iq, oq);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class SinValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
@@ -300,6 +428,17 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class SinQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::SIN, false, iq, oq);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class RoundValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
@@ -309,6 +448,17 @@ public:
ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ROUND);
}
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class RoundQuantizedValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType data_type, QuantizationInfo iq, QuantizationInfo oq)
+ {
+ ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ROUND, false, iq, oq);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute