aboutsummaryrefslogtreecommitdiff
path: root/tests/validation_new/Helpers.h
diff options
context:
space:
mode:
authorMoritz Pflanzer <moritz.pflanzer@arm.com>2017-07-21 17:36:33 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commit572ade736ab344a62afa7da214cd9407fe53a281 (patch)
treeadc0b31c0e236b65822dcbc9fb45ce401cc6ead4 /tests/validation_new/Helpers.h
parent8e6faf1e9f1af7a03441612c30644776e87fd235 (diff)
downloadComputeLibrary-572ade736ab344a62afa7da214cd9407fe53a281.tar.gz
COMPMID-415: Move ActivationLayer to new validation
Change-Id: I38ce20d95640f9c1baf699a095c35e592ad4339f Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81115 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'tests/validation_new/Helpers.h')
-rw-r--r--tests/validation_new/Helpers.h85
1 files changed, 85 insertions, 0 deletions
diff --git a/tests/validation_new/Helpers.h b/tests/validation_new/Helpers.h
index e25b684c11..3058b8eaee 100644
--- a/tests/validation_new/Helpers.h
+++ b/tests/validation_new/Helpers.h
@@ -24,9 +24,13 @@
#ifndef __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
#define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
#include "tests/validation/half.h"
+#include <random>
#include <type_traits>
+#include <utility>
namespace arm_compute
{
@@ -43,6 +47,87 @@ template <>
struct is_floating_point<half_float::half> : public std::true_type
{
};
+
+/** Helper function to get the testing range for each activation layer.
+ *
+ * @param[in] activation Activation function to test.
+ * @param[in] data_type Data type.
+ * @param[in] fixed_point_position Number of bits for the fractional part. Defaults to 1.
+ *
+ * @return A pair containing the lower upper testing bounds for a given function.
+ */
+template <typename T>
+std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::ActivationFunction activation, DataType data_type, int fixed_point_position = 0)
+{
+ std::pair<T, T> bounds;
+
+ switch(data_type)
+ {
+ case DataType::F16:
+ {
+ using namespace half_float::literal;
+
+ switch(activation)
+ {
+ case ActivationLayerInfo::ActivationFunction::SQUARE:
+ case ActivationLayerInfo::ActivationFunction::LOGISTIC:
+ case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
+ // Reduce range as exponent overflows
+ bounds = std::make_pair(-10._h, 10._h);
+ break;
+ case ActivationLayerInfo::ActivationFunction::SQRT:
+ // Reduce range as sqrt should take a non-negative number
+ bounds = std::make_pair(0._h, 255._h);
+ break;
+ default:
+ bounds = std::make_pair(-255._h, 255._h);
+ break;
+ }
+ break;
+ }
+ case DataType::F32:
+ switch(activation)
+ {
+ case ActivationLayerInfo::ActivationFunction::LOGISTIC:
+ case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
+ // Reduce range as exponent overflows
+ bounds = std::make_pair(-40.f, 40.f);
+ break;
+ case ActivationLayerInfo::ActivationFunction::SQRT:
+ // Reduce range as sqrt should take a non-negative number
+ bounds = std::make_pair(0.f, 255.f);
+ break;
+ default:
+ bounds = std::make_pair(-255.f, 255.f);
+ break;
+ }
+ break;
+ case DataType::QS8:
+ case DataType::QS16:
+ switch(activation)
+ {
+ case ActivationLayerInfo::ActivationFunction::LOGISTIC:
+ case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
+ case ActivationLayerInfo::ActivationFunction::TANH:
+ // Reduce range as exponent overflows
+ bounds = std::make_pair(-(1 << fixed_point_position), 1 << fixed_point_position);
+ break;
+ case ActivationLayerInfo::ActivationFunction::SQRT:
+ // Reduce range as sqrt should take a non-negative number
+ // Can't be zero either as inv_sqrt is used in NEON.
+ bounds = std::make_pair(1, std::numeric_limits<T>::max());
+ break;
+ default:
+ bounds = std::make_pair(std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
+ break;
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
+
+ return bounds;
+}
} // namespace validation
} // namespace test
} // namespace arm_compute