aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMoritz Pflanzer <moritz.pflanzer@arm.com>2017-07-21 15:55:28 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commite49e26613264842f91d29a32be3a226a0d6adb42 (patch)
tree78d88bded1f178d06b9dbfe3950ba716ef229599
parent27b386cb7596542a3296c32e41f7a5168b4d53be (diff)
downloadComputeLibrary-e49e26613264842f91d29a32be3a226a0d6adb42.tar.gz
COMPMID-415: Use half_float library for F16
3RDPARTY_UPDATE Change-Id: Iee572e18d5b1df71300d738cc8690f49d7203d5c Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81353 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
m---------3rdparty0
-rw-r--r--src/core/CL/cl_kernels/gemm.cl2
-rw-r--r--tests/AssetsLibrary.h18
-rw-r--r--tests/Utils.h13
-rw-r--r--tests/validation/CL/ArithmeticAddition.cpp2
-rw-r--r--tests/validation/CL/ConvolutionLayer.cpp7
-rw-r--r--tests/validation/Helpers.h16
-rw-r--r--tests/validation/Reference.cpp8
-rw-r--r--tests/validation/TensorFactory.h23
-rw-r--r--tests/validation/TensorOperations.h76
-rw-r--r--tests/validation/Validation.cpp18
-rw-r--r--tests/validation/half.h33
-rw-r--r--tests/validation_new/Helpers.h49
-rw-r--r--tests/validation_new/Validation.cpp9
-rw-r--r--tests/validation_new/Validation.h8
15 files changed, 147 insertions, 135 deletions
diff --git a/3rdparty b/3rdparty
-Subproject ca8086c3456a56ab7c963968281470691f5b982
+Subproject 473b15cd5e41fc530b8619510ce45894b34739d
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index db15720ad0..00c73e7be0 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -754,7 +754,7 @@ __kernel void gemm_mm_f16(IMAGE_DECLARATION(src0),
half8 c20 = 0.0f;
half8 c30 = 0.0f;
- for(; src_addr.s1 <= (end_row_mtx_b - 8); src_addr += (int2)(8, 16))
+ for(; src_addr.s1 <= (end_row_mtx_b - 16); src_addr += (int2)(8, 16))
{
/* Load values from matrix A (interleaved) and matrix B (transposed) */
half4 a0 = vload4(0, ((__global half *)src0_ptr) + src_addr.s0);
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index 6ecaccbd76..58738f871d 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -24,10 +24,6 @@
#ifndef __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__
#define __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__
-#include "RawTensor.h"
-#include "TensorCache.h"
-#include "Utils.h"
-
#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
@@ -35,6 +31,10 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Window.h"
+#include "tests/RawTensor.h"
+#include "tests/TensorCache.h"
+#include "tests/Utils.h"
+#include "tests/validation/half.h"
#include <algorithm>
#include <cstddef>
@@ -43,10 +43,6 @@
#include <string>
#include <type_traits>
-#if ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
namespace arm_compute
{
namespace test
@@ -476,9 +472,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
fill(tensor, distribution_s64, seed_offset);
break;
}
-#if ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
case DataType::F32:
{
// It doesn't make sense to check [-inf, inf], so hard code it to a big number
@@ -567,14 +561,12 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
fill(tensor, distribution_s64, seed_offset);
break;
}
-#if ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
- std::uniform_real_distribution<float_t> distribution_f16(low, high);
+ std::uniform_real_distribution<float> distribution_f16(low, high);
fill(tensor, distribution_f16, seed_offset);
break;
}
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
case DataType::F32:
{
ARM_COMPUTE_ERROR_ON(!(std::is_same<float, D>::value));
diff --git a/tests/Utils.h b/tests/Utils.h
index ad45bffe6e..0a58d41e35 100644
--- a/tests/Utils.h
+++ b/tests/Utils.h
@@ -31,6 +31,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "support/ToolchainSupport.h"
+#include "tests/validation/half.h"
#include <cmath>
#include <cstddef>
@@ -40,10 +41,6 @@
#include <string>
#include <type_traits>
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
namespace arm_compute
{
namespace test
@@ -100,9 +97,7 @@ template <> struct promote<int16_t> { using type = int32_t; };
template <> struct promote<uint32_t> { using type = uint64_t; };
template <> struct promote<int32_t> { using type = int64_t; };
template <> struct promote<float> { using type = float; };
-#ifdef ARM_COMPUTE_ENABLE_FP16
-template <> struct promote<float16_t> { using type = float16_t; };
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+template <> struct promote<half_float::half> { using type = half_float::half; };
template <typename T>
@@ -248,11 +243,9 @@ void store_value_with_data_type(void *ptr, T value, DataType data_type)
case DataType::S64:
*reinterpret_cast<int64_t *>(ptr) = value;
break;
-#if ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
- *reinterpret_cast<float16_t *>(ptr) = value;
+ *reinterpret_cast<half_float::half *>(ptr) = value;
break;
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
case DataType::F32:
*reinterpret_cast<float *>(ptr) = value;
break;
diff --git a/tests/validation/CL/ArithmeticAddition.cpp b/tests/validation/CL/ArithmeticAddition.cpp
index 66704761cd..fc1bf5905d 100644
--- a/tests/validation/CL/ArithmeticAddition.cpp
+++ b/tests/validation/CL/ArithmeticAddition.cpp
@@ -244,7 +244,6 @@ BOOST_DATA_TEST_CASE(RunLarge, LargeShapes() * ConvertPolicies() * boost::unit_t
BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE_END()
-#ifdef ARM_COMPUTE_ENABLE_FP16
BOOST_AUTO_TEST_SUITE(F16)
BOOST_DATA_TEST_CASE(RunSmall, SmallShapes(), shape)
{
@@ -258,7 +257,6 @@ BOOST_DATA_TEST_CASE(RunSmall, SmallShapes(), shape)
validate(CLAccessor(dst), ref_dst);
}
BOOST_AUTO_TEST_SUITE_END()
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
BOOST_AUTO_TEST_SUITE(F32)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp
index 6123571de1..a3d7140f99 100644
--- a/tests/validation/CL/ConvolutionLayer.cpp
+++ b/tests/validation/CL/ConvolutionLayer.cpp
@@ -45,6 +45,7 @@ using namespace arm_compute::test::validation;
namespace
{
+const float tolerance_f16 = 1.f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
const float tolerance_q = 1.0f; /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
@@ -73,7 +74,7 @@ CLTensor compute_convolution_layer(const TensorShape &input_shape, const TensorS
BOOST_TEST(!dst.info()->is_resizable());
// Fill tensors
- if(dt == DataType::F32)
+ if(dt == DataType::F32 || dt == DataType::F16)
{
std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
library->fill(CLAccessor(src), distribution, 0);
@@ -134,7 +135,6 @@ BOOST_DATA_TEST_CASE(Configuration,
validate(dst.info()->valid_region(), dst_valid_region);
}
-#ifdef ARM_COMPUTE_ENABLE_FP16
BOOST_AUTO_TEST_SUITE(Float16)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
BOOST_DATA_TEST_CASE(SmallConvolutionLayer,
@@ -148,10 +148,9 @@ BOOST_DATA_TEST_CASE(SmallConvolutionLayer,
RawTensor ref_dst = Reference::compute_reference_convolution_layer(conv_set.src_shape, conv_set.weights_shape, conv_set.bias_shape, conv_set.dst_shape, dt, conv_set.info, 0);
// Validate output
- validate(CLAccessor(dst), ref_dst, tolerance_f32);
+ validate(CLAccessor(dst), ref_dst, tolerance_f16);
}
BOOST_AUTO_TEST_SUITE_END()
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
BOOST_AUTO_TEST_SUITE(Float)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 191e32813c..2793c22147 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -24,21 +24,17 @@
#ifndef __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
#define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
-#include "ILutAccessor.h"
-#include "Types.h"
-#include "ValidationUserConfiguration.h"
-
#include "arm_compute/core/Types.h"
+#include "tests/ILutAccessor.h"
+#include "tests/Types.h"
+#include "tests/validation/ValidationUserConfiguration.h"
+#include "tests/validation/half.h"
#include <random>
#include <type_traits>
#include <utility>
#include <vector>
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h>
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
namespace arm_compute
{
namespace test
@@ -56,9 +52,7 @@ template <typename T>
inline std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::ActivationFunction activation, int fixed_point_position = 1)
{
bool is_float = std::is_same<T, float>::value;
-#ifdef ARM_COMPUTE_ENABLE_FP16
- is_float = is_float || std::is_same<T, float16_t>::value;
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+ is_float = is_float || std::is_same<T, half_float::half>::value;
std::pair<T, T> bounds;
diff --git a/tests/validation/Reference.cpp b/tests/validation/Reference.cpp
index 1db3c3f5fb..b94a0e5195 100644
--- a/tests/validation/Reference.cpp
+++ b/tests/validation/Reference.cpp
@@ -476,15 +476,13 @@ RawTensor Reference::compute_reference_activation_layer(const TensorShape &shape
library->fill(ref_src, distribution, 0);
break;
}
-#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
- const std::pair<float16_t, float16_t> bounds = get_activation_layer_test_bounds<float16_t>(act_info.activation());
+ const std::pair<half_float::half, half_float::half> bounds = get_activation_layer_test_bounds<half_float::half>(act_info.activation());
std::uniform_real_distribution<> distribution(bounds.first, bounds.second);
library->fill(ref_src, distribution, 0);
break;
}
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
case DataType::F32:
{
const std::pair<float, float> bounds = get_activation_layer_test_bounds<float>(act_info.activation());
@@ -604,9 +602,9 @@ RawTensor Reference::compute_reference_depth_concatenate_layer(const std::vector
TensorShape dst_shape = calculate_depth_concatenate_shape(shapes);
// Create tensors
- for(unsigned int i = 0; i < shapes.size(); ++i)
+ for(const auto &shape : shapes)
{
- ref_srcs.push_back(support::cpp14::make_unique<RawTensor>(RawTensor(shapes[i], dt, 1, fixed_point_position)));
+ ref_srcs.push_back(support::cpp14::make_unique<RawTensor>(shape, dt, 1, fixed_point_position));
}
RawTensor ref_dst(dst_shape, dt, 1, fixed_point_position);
diff --git a/tests/validation/TensorFactory.h b/tests/validation/TensorFactory.h
index 2f33dd283d..a3bb5f9615 100644
--- a/tests/validation/TensorFactory.h
+++ b/tests/validation/TensorFactory.h
@@ -24,29 +24,24 @@
#ifndef __ARM_COMPUTE_TEST_TENSOR_FACTORY_H__
#define __ARM_COMPUTE_TEST_TENSOR_FACTORY_H__
-#include "RawTensor.h"
-#include "Tensor.h"
#include "arm_compute/core/Error.h"
+#include "tests/RawTensor.h"
+#include "tests/validation/Tensor.h"
+#include "tests/validation/half.h"
#include "boost_wrapper.h"
-#if ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
namespace arm_compute
{
namespace test
{
namespace validation
{
-using TensorVariant = boost::variant < Tensor<uint8_t>, Tensor<int8_t>,
+using TensorVariant = boost::variant<Tensor<uint8_t>, Tensor<int8_t>,
Tensor<uint16_t>, Tensor<int16_t>,
Tensor<uint32_t>, Tensor<int32_t>,
-#ifdef ARM_COMPUTE_ENABLE_FP16
- Tensor<float16_t>,
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
- Tensor<float >>;
+ Tensor<half_float::half>,
+ Tensor<float>>;
/** Helper to create a constant type if the passed reference is constant. */
template <typename R, typename T>
@@ -95,12 +90,10 @@ public:
using value_type_s32 = typename match_const<R, int32_t>::type;
v = Tensor<int32_t>(shape, dt, fixed_point_position, reinterpret_cast<value_type_s32 *>(data));
break;
-#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
- using value_type_f16 = typename match_const<R, float16_t>::type;
- v = Tensor<float16_t>(shape, dt, fixed_point_position, reinterpret_cast<value_type_f16 *>(data));
+ using value_type_f16 = typename match_const<R, half_float::half>::type;
+ v = Tensor<half_float::half>(shape, dt, fixed_point_position, reinterpret_cast<value_type_f16 *>(data));
break;
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
case DataType::F32:
using value_type_f32 = typename match_const<R, float>::type;
v = Tensor<float>(shape, dt, fixed_point_position, reinterpret_cast<value_type_f32 *>(data));
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 319047816c..359dfe8d03 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -24,18 +24,15 @@
#ifndef __ARM_COMPUTE_TEST_TENSOR_OPERATIONS_H__
#define __ARM_COMPUTE_TEST_TENSOR_OPERATIONS_H__
-#include "FixedPoint.h"
-#include "Tensor.h"
-#include "Types.h"
-#include "Utils.h"
-#include "support/ToolchainSupport.h"
-
-#include "FixedPoint.h"
-#include "Types.h"
#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Types.h"
+#include "support/ToolchainSupport.h"
+#include "tests/Types.h"
+#include "tests/Utils.h"
#include "tests/validation/FixedPoint.h"
+#include "tests/validation/Tensor.h"
#include "tests/validation/ValidationUserConfiguration.h"
+#include "tests/validation/half.h"
#include <algorithm>
#include <array>
@@ -44,26 +41,6 @@
#include <string>
#include <vector>
-#if ARM_COMPUTE_ENABLE_FP16
-//Beware! most std templates acting on types don't work with the data type float16_t
-namespace std
-{
-template <>
-class numeric_limits<float16_t>
-{
-public:
- static float16_t lowest()
- {
- return -std::numeric_limits<float>::max(); // -inf
- };
- static float16_t max()
- {
- return std::numeric_limits<float>::max(); // +inf
- };
-};
-}
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
namespace arm_compute
{
namespace test
@@ -77,11 +54,8 @@ namespace
template <class T>
struct is_floating_point
: std::integral_constant < bool,
- std::is_same<float, typename std::remove_cv<T>::type>::value ||
-#ifdef ARM_COMPUTE_ENABLE_FP16
- std::is_same<float16_t, typename std::remove_cv<T>::type>::value ||
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
- std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value >
+ std::is_same<float, typename std::remove_cv<T>::type>::value || std::is_same<half_float::half, typename std::remove_cv<T>::type>::value
+ || std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value >
{
};
@@ -184,7 +158,7 @@ void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out
{
for(int x = 0; x < cols_weights; ++x)
{
- T acc = 0.0f;
+ T acc(0);
for(int y = 0; y < rows_weights; ++y)
{
acc += in[y] * weights[x + y * cols_weights];
@@ -456,8 +430,8 @@ void absolute_difference(const Tensor<T1> &in1, const Tensor<T2> &in2, Tensor<T3
for(int i = 0; i < in1.num_elements(); ++i)
{
- intermediate_type val = std::abs(static_cast<intermediate_type>(in1[i]) - static_cast<intermediate_type>(in2[i]));
- out[i] = saturate_cast<T3>(val);
+ intermediate_type val(std::abs(static_cast<intermediate_type>(in1[i]) - static_cast<intermediate_type>(in2[i])));
+ out[i] = saturate_cast<T3>(val);
}
}
@@ -708,7 +682,7 @@ void gemm(const Tensor<T> &in1, const Tensor<T> &in2, const Tensor<T> &in3, Tens
{
for(int c = 0; c < N; ++c)
{
- T acc = 0.0f;
+ T acc(0);
for(int k = 0; k < K; ++k)
{
@@ -967,10 +941,10 @@ void activation_layer(const Tensor<T> &in, Tensor<T> &out, ActivationLayerInfo a
out[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
break;
case ActivationLayerInfo::ActivationFunction::RELU:
- out[i] = std::max<T>(0, x);
+ out[i] = std::max(static_cast<T>(0), x);
break;
case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
- out[i] = std::min<T>(a, std::max<T>(0, x));
+ out[i] = std::min<T>(a, std::max(static_cast<T>(0), x));
break;
case ActivationLayerInfo::ActivationFunction::LEAKY_RELU:
out[i] = (x > 0) ? x : a * x;
@@ -1519,16 +1493,16 @@ void pooling_layer(const Tensor<T> &in, Tensor<T> &out, PoolingLayerInfo pool_in
{
for(int w = 0; w < pooled_w; ++w)
{
- T avg_val = 0;
- int wstart = w * pool_stride_x - pad_x;
- int hstart = h * pool_stride_y - pad_y;
- int wend = std::min(wstart + pool_size, w_in + pad_x);
- int hend = std::min(hstart + pool_size, h_in + pad_y);
- int pool = (hend - hstart) * (wend - wstart);
- wstart = std::max(wstart, 0);
- hstart = std::max(hstart, 0);
- wend = std::min(wend, w_in);
- hend = std::min(hend, h_in);
+ T avg_val(0);
+ int wstart = w * pool_stride_x - pad_x;
+ int hstart = h * pool_stride_y - pad_y;
+ int wend = std::min(wstart + pool_size, w_in + pad_x);
+ int hend = std::min(hstart + pool_size, h_in + pad_y);
+ int pool = (hend - hstart) * (wend - wstart);
+ wstart = std::max(wstart, 0);
+ hstart = std::max(hstart, 0);
+ wend = std::min(wend, w_in);
+ hend = std::min(hend, h_in);
if(is_floating_point<T>::value)
{
for(int y = hstart; y < hend; ++y)
@@ -1652,7 +1626,7 @@ void softmax_layer(const Tensor<T> &in, Tensor<T> &out)
}
// Regularize
- T sum = 0;
+ T sum(0);
for(int c = 0; c < cols; ++c)
{
const T res = exp(in[r * cols + c] - max);
@@ -1661,7 +1635,7 @@ void softmax_layer(const Tensor<T> &in, Tensor<T> &out)
}
// Normalize
- const T norm_val = 1 / sum;
+ const T norm_val = static_cast<T>(1) / sum;
for(int c = 0; c < cols; ++c)
{
out[r * cols + c] *= norm_val;
diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp
index 14ee98a96b..a13eeb0b85 100644
--- a/tests/validation/Validation.cpp
+++ b/tests/validation/Validation.cpp
@@ -23,16 +23,16 @@
*/
#include "Validation.h"
-#include "IAccessor.h"
-#include "RawTensor.h"
-#include "TypePrinter.h"
-#include "Utils.h"
-
#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/runtime/Tensor.h"
+#include "tests/IAccessor.h"
+#include "tests/RawTensor.h"
+#include "tests/TypePrinter.h"
+#include "tests/Utils.h"
+#include "tests/validation/half.h"
#include <array>
#include <cmath>
@@ -40,10 +40,6 @@
#include <cstdint>
#include <iomanip>
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
namespace arm_compute
{
namespace test
@@ -88,10 +84,8 @@ double get_double_data(const void *ptr, DataType data_type)
return *reinterpret_cast<const uint64_t *>(ptr);
case DataType::S64:
return *reinterpret_cast<const int64_t *>(ptr);
-#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
- return *reinterpret_cast<const float16_t *>(ptr);
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+ return *reinterpret_cast<const half_float::half *>(ptr);
case DataType::F32:
return *reinterpret_cast<const float *>(ptr);
case DataType::F64:
diff --git a/tests/validation/half.h b/tests/validation/half.h
new file mode 100644
index 0000000000..fb2235aad9
--- /dev/null
+++ b/tests/validation/half.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_HALF_H__
+#define __ARM_COMPUTE_TEST_HALF_H__
+
+#ifdef __ANDROID__
+// Android toolchain is broken and doesn't support all CPP11 math functions.
+#define HALF_ENABLE_CPP11_CMATH 0
+#endif /* __ANDROID__ */
+
+#include "half/half.hpp"
+#endif /* __ARM_COMPUTE_TEST_HALF_H__ */
diff --git a/tests/validation_new/Helpers.h b/tests/validation_new/Helpers.h
new file mode 100644
index 0000000000..e25b684c11
--- /dev/null
+++ b/tests/validation_new/Helpers.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
+#define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__
+
+#include "tests/validation/half.h"
+
+#include <type_traits>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+template <typename T>
+struct is_floating_point : public std::is_floating_point<T>
+{
+};
+
+template <>
+struct is_floating_point<half_float::half> : public std::true_type
+{
+};
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ */
diff --git a/tests/validation_new/Validation.cpp b/tests/validation_new/Validation.cpp
index 8ab8274d2a..9071663e7c 100644
--- a/tests/validation_new/Validation.cpp
+++ b/tests/validation_new/Validation.cpp
@@ -27,16 +27,13 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/runtime/Tensor.h"
+#include "tests/validation/half.h"
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
-#ifdef ARM_COMPUTE_ENABLE_FP16
-#include <arm_fp16.h> // needed for float16_t
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
-
namespace arm_compute
{
namespace test
@@ -81,10 +78,8 @@ double get_double_data(const void *ptr, DataType data_type)
return *reinterpret_cast<const uint64_t *>(ptr);
case DataType::S64:
return *reinterpret_cast<const int64_t *>(ptr);
-#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
- return *reinterpret_cast<const float16_t *>(ptr);
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
+ return *reinterpret_cast<const half_float::half *>(ptr);
case DataType::F32:
return *reinterpret_cast<const float *>(ptr);
case DataType::F64:
diff --git a/tests/validation_new/Validation.h b/tests/validation_new/Validation.h
index 5e947caf8d..7db7b00886 100644
--- a/tests/validation_new/Validation.h
+++ b/tests/validation_new/Validation.h
@@ -85,8 +85,8 @@ void validate(const arm_compute::PaddingSize &padding, const arm_compute::Paddin
* reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
* other test cases.
*/
-template <typename T, typename U>
-void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = 0, float tolerance_number = 0.f);
+template <typename T, typename U = T>
+void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(0), float tolerance_number = 0.f);
/** Validate tensors with valid region.
*
@@ -98,8 +98,8 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U toler
* reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
* other test cases.
*/
-template <typename T, typename U>
-void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = 0, float tolerance_number = 0.f);
+template <typename T, typename U = T>
+void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(0), float tolerance_number = 0.f);
/** Validate tensors against constant value.
*