aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/NEON/NEMath.h18
-rw-r--r--arm_compute/core/NEON/NEMath.inl23
-rw-r--r--arm_compute/core/NEON/kernels/NEActivationLayerKernel.h12
-rw-r--r--src/core/NEON/kernels/NEActivationLayerKernel.cpp151
-rw-r--r--tests/benchmark_new/NEON/ActivationLayer.cpp20
-rw-r--r--tests/validation/Helpers.h13
-rw-r--r--tests/validation/NEON/ActivationLayer.cpp73
-rw-r--r--tests/validation/Reference.cpp51
-rw-r--r--tests/validation/TensorOperations.h4
9 files changed, 316 insertions, 49 deletions
diff --git a/arm_compute/core/NEON/NEMath.h b/arm_compute/core/NEON/NEMath.h
index 8dd9d609e7..b467a600d6 100644
--- a/arm_compute/core/NEON/NEMath.h
+++ b/arm_compute/core/NEON/NEMath.h
@@ -93,6 +93,24 @@ float32x4_t vtanhq_f32(float32x4_t val);
float32x4_t vpowq_f32(float32x4_t val, float32x4_t n);
#ifdef ARM_COMPUTE_ENABLE_FP16
+/** Calculate hyperbolic tangent.
+ *
+ * tanh(x) = (e^2x - 1)/(e^2x + 1)
+ *
+ * @note We clamp x to [-5,5] to avoid overflowing issues.
+ *
+ * @param[in] val Input vector value in F32 format.
+ *
+ * @return The calculated Hyperbolic Tangent.
+ */
+float16x8_t vtanhq_f16(float16x8_t val);
+/** Calculate inverse square root.
+ *
+ * @param[in] x Input value.
+ *
+ * @return The calculated inverse square root.
+ */
+float16x8_t vinvsqrtq_f16(float16x8_t x);
/** Calculate exponential
*
* @param[in] x Input vector value in F16 format.
diff --git a/arm_compute/core/NEON/NEMath.inl b/arm_compute/core/NEON/NEMath.inl
index c73c54501f..1d90029147 100644
--- a/arm_compute/core/NEON/NEMath.inl
+++ b/arm_compute/core/NEON/NEMath.inl
@@ -172,6 +172,14 @@ const std::array<float16x8_t, 8> log_tab_f16 =
vdupq_n_f16(0.0141278216615f),
}
};
+inline float16x8_t vinvsqrtq_f16(float16x8_t x)
+{
+ float16x8_t sqrt_reciprocal = vrsqrteq_f16(x);
+ sqrt_reciprocal = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal);
+ sqrt_reciprocal = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal);
+
+ return sqrt_reciprocal;
+}
inline float16x8_t vinvq_f16(float16x8_t x)
{
@@ -181,6 +189,21 @@ inline float16x8_t vinvq_f16(float16x8_t x)
return recip;
}
+inline float16x8_t vtanhq_f16(float16x8_t val)
+{
+ const float16x8_t CONST_1 = vdupq_n_f16(1.f);
+ const float16x8_t CONST_2 = vdupq_n_f16(2.f);
+ const float16x8_t CONST_MIN_TANH = vdupq_n_f16(-10.f);
+ const float16x8_t CONST_MAX_TANH = vdupq_n_f16(10.f);
+
+ const float16x8_t x = vminq_f16(vmaxq_f16(val, CONST_MIN_TANH), CONST_MAX_TANH);
+ const float16x8_t exp2x = vexpq_f16(vmulq_f16(CONST_2, x));
+ const float16x8_t num = vsubq_f16(exp2x, CONST_1);
+ const float16x8_t den = vaddq_f16(exp2x, CONST_1);
+ const float16x8_t tanh = vmulq_f16(num, vinvq_f16(den));
+ return tanh;
+}
+
inline float16x8_t vtaylor_polyq_f16(float16x8_t x, const std::array<float16x8_t, 8> &coeffs)
{
const float16x8_t A = vaddq_f16(coeffs[0], vmulq_f16(coeffs[4], x));
diff --git a/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h b/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h
index e995f1e5e0..2c88debfb4 100644
--- a/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h
@@ -27,6 +27,10 @@
#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/NEON/INEKernel.h"
+#ifdef ARM_COMPUTE_ENABLE_FP16
+#include <arm_fp16.h>
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
namespace arm_compute
{
class ITensor;
@@ -72,6 +76,14 @@ private:
*/
template <ActivationLayerInfo::ActivationFunction F, typename T>
typename std::enable_if<std::is_same<T, float>::value, void>::type activation(const Window &window);
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ /** Function to apply an activation function on a tensor.
+ *
+ * @param[in] window Region on which to execute the kernel
+ */
+ template <ActivationLayerInfo::ActivationFunction F, typename T>
+ typename std::enable_if<std::is_same<T, float16_t>::value, void>::type activation(const Window &window);
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
/** Function to apply an activation function on a tensor.
*
* @param[in] window Region on which to execute the kernel
diff --git a/src/core/NEON/kernels/NEActivationLayerKernel.cpp b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
index 70b7057fcd..3195411e18 100644
--- a/src/core/NEON/kernels/NEActivationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
@@ -47,7 +47,7 @@ NEActivationLayerKernel::NEActivationLayerKernel()
void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, ActivationLayerInfo activation_info)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
_input = input;
_act_info = activation_info;
@@ -79,6 +79,23 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat
{ ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, float> },
{ ActivationFunction::TANH, &NEActivationLayerKernel::activation<ActivationFunction::TANH, float> },
};
+
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ // Activation functions : FP16
+ static std::map<ActivationFunction, ActivationFunctionExecutorPtr> act_map_f16 =
+ {
+ { ActivationFunction::ABS, &NEActivationLayerKernel::activation<ActivationFunction::ABS, float16_t> },
+ { ActivationFunction::LINEAR, &NEActivationLayerKernel::activation<ActivationFunction::LINEAR, float16_t> },
+ { ActivationFunction::LOGISTIC, &NEActivationLayerKernel::activation<ActivationFunction::LOGISTIC, float16_t> },
+ { ActivationFunction::RELU, &NEActivationLayerKernel::activation<ActivationFunction::RELU, float16_t> },
+ { ActivationFunction::BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::BOUNDED_RELU, float16_t> },
+ { ActivationFunction::SOFT_RELU, &NEActivationLayerKernel::activation<ActivationFunction::SOFT_RELU, float16_t> },
+ { ActivationFunction::SQRT, &NEActivationLayerKernel::activation<ActivationFunction::SQRT, float16_t> },
+ { ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, float16_t> },
+ { ActivationFunction::TANH, &NEActivationLayerKernel::activation<ActivationFunction::TANH, float16_t> },
+ };
+#endif /* ARM_COMPUTE_ENABLE_FP16*/
+
// Activation functions : QS8
static std::map<ActivationFunction, ActivationFunctionExecutorPtr> act_map_qs8 =
{
@@ -119,6 +136,11 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat
case DataType::F32:
_func = act_map_f32[activation_info.activation()];
break;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ case DataType::F16:
+ _func = act_map_f16[activation_info.activation()];
+ break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
default:
ARM_COMPUTE_ERROR("Unsupported data type.");
}
@@ -148,6 +170,130 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat
ICPPKernel::configure(win);
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+template <ActivationLayerInfo::ActivationFunction F, typename T>
+typename std::enable_if<std::is_same<T, float16_t>::value, void>::type NEActivationLayerKernel::activation(const Window &window)
+{
+ Iterator input(_input, window);
+ Iterator output(_output, window);
+
+ static const float16x8_t CONST_0 = vdupq_n_f16(0.f);
+ static const float16x8_t CONST_1 = vdupq_n_f16(1.f);
+
+ const float16x8_t a = vdupq_n_f16(_act_info.a());
+ const float16x8_t b = vdupq_n_f16(_act_info.b());
+
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<const float16_t *>(input.ptr());
+ const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());
+
+ const float16x8x2_t in = vld2q_f16(input_ptr);
+ float16x8x2_t tmp = { {} };
+
+ switch(F)
+ {
+ case ActivationFunction::ABS:
+ tmp =
+ {
+ {
+ vabsq_f16(in.val[0]),
+ vabsq_f16(in.val[1]),
+ }
+ };
+ break;
+ case ActivationFunction::BOUNDED_RELU:
+ tmp =
+ {
+ {
+ vminq_f16(a, vmaxq_f16(CONST_0, in.val[0])),
+ vminq_f16(a, vmaxq_f16(CONST_0, in.val[1]))
+ }
+ };
+ break;
+ case ActivationFunction::LINEAR:
+ tmp =
+ {
+ {
+ vaddq_f16(b, vmulq_f16(a, in.val[0])),
+ vaddq_f16(b, vmulq_f16(a, in.val[1]))
+ }
+ };
+ break;
+ case ActivationFunction::LOGISTIC:
+ tmp =
+ {
+ {
+ vinvq_f16(vaddq_f16(CONST_1, vexpq_f16(vnegq_f16(in.val[0])))),
+ vinvq_f16(vaddq_f16(CONST_1, vexpq_f16(vnegq_f16(in.val[1])))),
+ }
+ };
+ break;
+ case ActivationFunction::RELU:
+ tmp =
+ {
+ {
+ vmaxq_f16(CONST_0, in.val[0]),
+ vmaxq_f16(CONST_0, in.val[1])
+ }
+ };
+ break;
+ case ActivationFunction::LEAKY_RELU:
+ tmp =
+ {
+ {
+ vbslq_f16(vcgtq_f16(in.val[0], CONST_0), in.val[0], vmulq_f16(a, in.val[0])),
+ vbslq_f16(vcgtq_f16(in.val[1], CONST_0), in.val[1], vmulq_f16(a, in.val[1]))
+ }
+ };
+ break;
+ case ActivationFunction::SOFT_RELU:
+ tmp =
+ {
+ {
+ vlogq_f16(vaddq_f16(CONST_1, vexpq_f16(in.val[0]))),
+ vlogq_f16(vaddq_f16(CONST_1, vexpq_f16(in.val[1]))),
+ }
+ };
+ break;
+ case ActivationFunction::SQRT:
+ tmp =
+ {
+ {
+ vinvq_f16(vinvsqrtq_f16(in.val[0])),
+ vinvq_f16(vinvsqrtq_f16(in.val[1])),
+ }
+ };
+ break;
+ case ActivationFunction::SQUARE:
+ tmp =
+ {
+ {
+ vmulq_f16(in.val[0], in.val[0]),
+ vmulq_f16(in.val[1], in.val[1])
+ }
+ };
+ break;
+ case ActivationFunction::TANH:
+ tmp =
+ {
+ {
+ vmulq_f16(a, vtanhq_f16(vmulq_f16(b, in.val[0]))),
+ vmulq_f16(a, vtanhq_f16(vmulq_f16(b, in.val[1]))),
+ }
+ };
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ break;
+ }
+
+ vst2q_f16(output_ptr, tmp);
+ },
+ input, output);
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
template <ActivationLayerInfo::ActivationFunction F, typename T>
typename std::enable_if<std::is_same<T, float>::value, void>::type NEActivationLayerKernel::activation(const Window &window)
{
@@ -350,7 +496,7 @@ typename std::enable_if<std::is_same<T, int8_t>::value, void>::type NEActivation
}
template <ActivationLayerInfo::ActivationFunction F, typename T>
-typename std::enable_if<std::is_same<T, int16_t>::value, void>::type NEActivationLayerKernel::activation(const Window &window)
+typename std::enable_if<std::is_same<T, qint16_t>::value, void>::type NEActivationLayerKernel::activation(const Window &window)
{
Iterator input(_input, window);
Iterator output(_output, window);
@@ -462,6 +608,7 @@ typename std::enable_if<std::is_same<T, int16_t>::value, void>::type NEActivatio
};
break;
default:
+ ARM_COMPUTE_ERROR("Function not implemented");
break;
}
diff --git a/tests/benchmark_new/NEON/ActivationLayer.cpp b/tests/benchmark_new/NEON/ActivationLayer.cpp
index 47838f43ba..21e4369aa2 100644
--- a/tests/benchmark_new/NEON/ActivationLayer.cpp
+++ b/tests/benchmark_new/NEON/ActivationLayer.cpp
@@ -37,23 +37,31 @@ namespace arm_compute
{
namespace test
{
+namespace
+{
+#ifdef ARM_COMPUTE_ENABLE_FP16
+const auto alexnet_data_types = framework::dataset::make("DataType", { DataType::QS8, DataType::F16, DataType::F32 });
+const auto lenet_data_types = framework::dataset::make("DataType", { DataType::F16, DataType::F32 });
+#else /* ARM_COMPUTE_ENABLE_FP16 */
+const auto alexnet_data_types = framework::dataset::make("DataType", { DataType::QS8, DataType::F32 });
+const auto lenet_data_types = framework::dataset::make("DataType", { DataType::F32 });
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+} // namespace
+
using NEActivationLayerFixture = ActivationLayerFixture<Tensor, NEActivationLayer, neon::NEAccessor>;
TEST_SUITE(NEON)
REGISTER_FIXTURE_DATA_TEST_CASE(AlexNetActivationLayer, NEActivationLayerFixture, framework::DatasetMode::ALL,
- framework::dataset::combine(framework::dataset::combine(datasets::AlexNetActivationLayerDataset(),
- framework::dataset::make("DataType", { DataType::F32, DataType::QS8 })),
+ framework::dataset::combine(framework::dataset::combine(datasets::AlexNetActivationLayerDataset(), alexnet_data_types),
framework::dataset::make("Batches", { 1, 4, 8 })));
REGISTER_FIXTURE_DATA_TEST_CASE(LeNet5ActivationLayer, NEActivationLayerFixture, framework::DatasetMode::ALL,
- framework::dataset::combine(framework::dataset::combine(datasets::LeNet5ActivationLayerDataset(),
- framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::combine(framework::dataset::combine(datasets::LeNet5ActivationLayerDataset(), lenet_data_types),
framework::dataset::make("Batches", { 1, 4, 8 })));
REGISTER_FIXTURE_DATA_TEST_CASE(GoogLeNetActivationLayer, NEActivationLayerFixture, framework::DatasetMode::ALL,
- framework::dataset::combine(framework::dataset::combine(datasets::GoogLeNetActivationLayerDataset(),
- framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::combine(framework::dataset::combine(datasets::GoogLeNetActivationLayerDataset(), lenet_data_types),
framework::dataset::make("Batches", { 1, 4, 8 })));
TEST_SUITE_END()
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 4ee2112bcc..191e32813c 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -35,6 +35,10 @@
#include <utility>
#include <vector>
+#ifdef ARM_COMPUTE_ENABLE_FP16
+#include <arm_fp16.h>
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
namespace arm_compute
{
namespace test
@@ -49,9 +53,13 @@ namespace validation
* @return A pair containing the lower upper testing bounds for a given function.
*/
template <typename T>
-std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::ActivationFunction activation, int fixed_point_position = 1)
+inline std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::ActivationFunction activation, int fixed_point_position = 1)
{
- bool is_float = std::is_floating_point<T>::value;
+ bool is_float = std::is_same<T, float>::value;
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ is_float = is_float || std::is_same<T, float16_t>::value;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
std::pair<T, T> bounds;
// Set initial values
@@ -98,7 +106,6 @@ std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::Activation
}
return bounds;
}
-
/** Helper function to get the testing range for batch normalization layer.
*
* @param[in] fixed_point_position (Optional) Number of bits for the fractional part. Defaults to 1.
diff --git a/tests/validation/NEON/ActivationLayer.cpp b/tests/validation/NEON/ActivationLayer.cpp
index 2b24fd5175..b8827a5324 100644
--- a/tests/validation/NEON/ActivationLayer.cpp
+++ b/tests/validation/NEON/ActivationLayer.cpp
@@ -73,6 +73,8 @@ float activation_layer_tolerance(DataType dt, ActivationLayerInfo::ActivationFun
return 5.f;
case DataType::QS16:
return 11.f;
+ case DataType::F16:
+ return 0.01f;
default:
return 0.00001f;
}
@@ -119,30 +121,44 @@ Tensor compute_activation_layer(bool in_place, const TensorShape &shape, DataTyp
dst.allocator()->allocate();
BOOST_TEST(!dst.info()->is_resizable());
}
-
// Fill tensors
- if(dt == DataType::F32)
- {
- float min_bound = 0;
- float max_bound = 0;
- std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<float>(act_info.activation());
- std::uniform_real_distribution<> distribution(min_bound, max_bound);
- library->fill(NEAccessor(src), distribution, 0);
- }
- else
+ switch(dt)
{
- int min_bound = 0;
- int max_bound = 0;
- if(dt == DataType::QS8)
+ case DataType::QS8:
+ {
+ const std::pair<int8_t, int8_t> bounds = get_activation_layer_test_bounds<int8_t>(act_info.activation(), fixed_point_position);
+ std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(NEAccessor(src), distribution, 0);
+ break;
+ }
+ case DataType::QS16:
+ {
+ const std::pair<int16_t, int16_t> bounds = get_activation_layer_test_bounds<int16_t>(act_info.activation(), fixed_point_position);
+ std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(NEAccessor(src), distribution, 0);
+ break;
+ }
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ case DataType::F16:
+ {
+ const std::pair<float16_t, float16_t> bounds = get_activation_layer_test_bounds<float16_t>(act_info.activation());
+ std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(NEAccessor(src), distribution, 0);
+ break;
+ }
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+ case DataType::F32:
{
- std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<int8_t>(act_info.activation(), fixed_point_position);
+ const std::pair<float, float> bounds = get_activation_layer_test_bounds<float>(act_info.activation());
+ std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(NEAccessor(src), distribution, 0);
+ break;
}
- else
+ default:
{
- std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<int16_t>(act_info.activation(), fixed_point_position);
+ ARM_COMPUTE_ERROR("Not supported");
+ break;
}
- std::uniform_int_distribution<> distribution(min_bound, max_bound);
- library->fill(NEAccessor(src), distribution, 0);
}
// Compute function
@@ -207,6 +223,27 @@ BOOST_DATA_TEST_CASE(Configuration, boost::unit_test::data::make({ false, true }
}
}
+#ifdef ARM_COMPUTE_ENABLE_FP16
+BOOST_AUTO_TEST_SUITE(Float16)
+BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
+BOOST_DATA_TEST_CASE(RunSmall, boost::unit_test::data::make({ false, true }) * SmallShapes() * boost::unit_test::data::make(DataType::F16) * ActivationFunctions() * boost::unit_test::data::make({ 0.5f, 1.f }),
+ in_place, shape, dt, act_function, alpha_beta)
+{
+ // Create activation layer info
+ const ActivationLayerInfo act_info(act_function, alpha_beta);
+
+ // Compute function
+ Tensor dst = compute_activation_layer(in_place, shape, dt, act_info);
+
+ // Compute reference
+ RawTensor ref_dst = Reference::compute_reference_activation_layer(shape, dt, act_info);
+
+ // Validate output
+ validate(NEAccessor(dst), ref_dst, activation_layer_tolerance(dt, act_function));
+}
+BOOST_AUTO_TEST_SUITE_END()
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
BOOST_AUTO_TEST_SUITE(Float)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
BOOST_DATA_TEST_CASE(RunSmall, boost::unit_test::data::make({ false, true }) * SmallShapes() * CNNFloatDataTypes() * ActivationFunctions() * boost::unit_test::data::make({ 0.5f, 1.f }),
diff --git a/tests/validation/Reference.cpp b/tests/validation/Reference.cpp
index 3b429c1ee6..0ce25c5567 100644
--- a/tests/validation/Reference.cpp
+++ b/tests/validation/Reference.cpp
@@ -459,29 +459,44 @@ RawTensor Reference::compute_reference_activation_layer(const TensorShape &shape
RawTensor ref_src = library->get(shape, dt, 1, fixed_point_position);
RawTensor ref_dst = library->get(shape, dt, 1, fixed_point_position);
- // Fill reference
- if(dt == DataType::F32)
- {
- float min_bound = 0;
- float max_bound = 0;
- std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<float>(act_info.activation());
- std::uniform_real_distribution<> distribution(min_bound, max_bound);
- library->fill(ref_src, distribution, 0);
- }
- else
+ // Fill tensors
+ switch(dt)
{
- int min_bound = 0;
- int max_bound = 0;
- if(dt == DataType::QS8)
+ case DataType::QS8:
{
- std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<int8_t>(act_info.activation(), fixed_point_position);
+ const std::pair<int8_t, int8_t> bounds = get_activation_layer_test_bounds<int8_t>(act_info.activation(), fixed_point_position);
+ std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(ref_src, distribution, 0);
+ break;
}
- else
+ case DataType::QS16:
{
- std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<int16_t>(act_info.activation(), fixed_point_position);
+ const std::pair<int16_t, int16_t> bounds = get_activation_layer_test_bounds<int16_t>(act_info.activation(), fixed_point_position);
+ std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(ref_src, distribution, 0);
+ break;
+ }
+#ifdef ARM_COMPUTE_ENABLE_FP16
+ case DataType::F16:
+ {
+ const std::pair<float16_t, float16_t> bounds = get_activation_layer_test_bounds<float16_t>(act_info.activation());
+ std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(ref_src, distribution, 0);
+ break;
+ }
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+ case DataType::F32:
+ {
+ const std::pair<float, float> bounds = get_activation_layer_test_bounds<float>(act_info.activation());
+ std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
+ library->fill(ref_src, distribution, 0);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Not supported");
+ break;
}
- std::uniform_int_distribution<> distribution(min_bound, max_bound);
- library->fill(ref_src, distribution, 0);
}
// Compute reference
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 9e201e2f04..67dadd6da3 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -569,7 +569,7 @@ void box3x3(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T const
}
// Depth conversion
-template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_floating_point<T2>::value, int >::type = 0 >
+template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&is_floating_point<T2>::value, int >::type = 0 >
void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift)
{
using namespace fixed_point_arithmetic;
@@ -581,7 +581,7 @@ void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy,
}
}
-template < typename T1, typename T2, typename std::enable_if < std::is_floating_point<T1>::value &&std::is_integral<T2>::value, int >::type = 0 >
+template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&std::is_integral<T2>::value, int >::type = 0 >
void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift)
{
using namespace fixed_point_arithmetic;