From e49e26613264842f91d29a32be3a226a0d6adb42 Mon Sep 17 00:00:00 2001 From: Moritz Pflanzer Date: Fri, 21 Jul 2017 15:55:28 +0100 Subject: COMPMID-415: Use half_float library for F16 3RDPARTY_UPDATE Change-Id: Iee572e18d5b1df71300d738cc8690f49d7203d5c Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81353 Tested-by: Kaizen Reviewed-by: Anthony Barbier --- 3rdparty | 2 +- src/core/CL/cl_kernels/gemm.cl | 2 +- tests/AssetsLibrary.h | 18 ++----- tests/Utils.h | 13 ++--- tests/validation/CL/ArithmeticAddition.cpp | 2 - tests/validation/CL/ConvolutionLayer.cpp | 7 ++- tests/validation/Helpers.h | 16 ++----- tests/validation/Reference.cpp | 8 ++-- tests/validation/TensorFactory.h | 23 ++++----- tests/validation/TensorOperations.h | 76 ++++++++++-------------------- tests/validation/Validation.cpp | 18 +++---- tests/validation/half.h | 33 +++++++++++++ tests/validation_new/Helpers.h | 49 +++++++++++++++++++ tests/validation_new/Validation.cpp | 9 +--- tests/validation_new/Validation.h | 8 ++-- 15 files changed, 148 insertions(+), 136 deletions(-) create mode 100644 tests/validation/half.h create mode 100644 tests/validation_new/Helpers.h diff --git a/3rdparty b/3rdparty index ca8086c345..473b15cd5e 160000 --- a/3rdparty +++ b/3rdparty @@ -1 +1 @@ -Subproject commit ca8086c3456a56ab7c963968281470691f5b9826 +Subproject commit 473b15cd5e41fc530b8619510ce45894b34739d2 diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index db15720ad0..00c73e7be0 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -754,7 +754,7 @@ __kernel void gemm_mm_f16(IMAGE_DECLARATION(src0), half8 c20 = 0.0f; half8 c30 = 0.0f; - for(; src_addr.s1 <= (end_row_mtx_b - 8); src_addr += (int2)(8, 16)) + for(; src_addr.s1 <= (end_row_mtx_b - 16); src_addr += (int2)(8, 16)) { /* Load values from matrix A (interleaved) and matrix B (transposed) */ half4 a0 = vload4(0, ((__global half *)src0_ptr) + src_addr.s0); diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h index 6ecaccbd76..58738f871d 100644 --- a/tests/AssetsLibrary.h +++ b/tests/AssetsLibrary.h @@ -24,10 +24,6 @@ #ifndef __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__ #define __ARM_COMPUTE_TEST_TENSOR_LIBRARY_H__ -#include "RawTensor.h" -#include "TensorCache.h" -#include "Utils.h" - #include "arm_compute/core/Coordinates.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" @@ -35,6 +31,10 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Window.h" +#include "tests/RawTensor.h" +#include "tests/TensorCache.h" +#include "tests/Utils.h" +#include "tests/validation/half.h" #include #include @@ -43,10 +43,6 @@ #include #include -#if ARM_COMPUTE_ENABLE_FP16 -#include // needed for float16_t -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - namespace arm_compute { namespace test @@ -476,9 +472,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t fill(tensor, distribution_s64, seed_offset); break; } -#if ARM_COMPUTE_ENABLE_FP16 case DataType::F16: -#endif /* ARM_COMPUTE_ENABLE_FP16 */ case DataType::F32: { // It doesn't make sense to check [-inf, inf], so hard code it to a big number @@ -567,14 +561,12 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t fill(tensor, distribution_s64, seed_offset); break; } -#if ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { - std::uniform_real_distribution distribution_f16(low, high); + std::uniform_real_distribution distribution_f16(low, high); fill(tensor, distribution_f16, seed_offset); break; } -#endif /* ARM_COMPUTE_ENABLE_FP16 */ case DataType::F32: { ARM_COMPUTE_ERROR_ON(!(std::is_same::value)); diff --git a/tests/Utils.h b/tests/Utils.h index ad45bffe6e..0a58d41e35 100644 --- a/tests/Utils.h +++ b/tests/Utils.h @@ -31,6 +31,7 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" #include "support/ToolchainSupport.h" +#include "tests/validation/half.h" #include #include @@ -40,10 +41,6 @@ #include #include -#ifdef ARM_COMPUTE_ENABLE_FP16 -#include // needed for float16_t -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - namespace arm_compute { namespace test @@ -100,9 +97,7 @@ template <> struct promote { using type = int32_t; }; template <> struct promote { using type = uint64_t; }; template <> struct promote { using type = int64_t; }; template <> struct promote { using type = float; }; -#ifdef ARM_COMPUTE_ENABLE_FP16 -template <> struct promote { using type = float16_t; }; -#endif /* ARM_COMPUTE_ENABLE_FP16 */ +template <> struct promote { using type = half_float::half; }; template @@ -248,11 +243,9 @@ void store_value_with_data_type(void *ptr, T value, DataType data_type) case DataType::S64: *reinterpret_cast(ptr) = value; break; -#if ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - *reinterpret_cast(ptr) = value; + *reinterpret_cast(ptr) = value; break; -#endif /* ARM_COMPUTE_ENABLE_FP16 */ case DataType::F32: *reinterpret_cast(ptr) = value; break; diff --git a/tests/validation/CL/ArithmeticAddition.cpp b/tests/validation/CL/ArithmeticAddition.cpp index 66704761cd..fc1bf5905d 100644 --- a/tests/validation/CL/ArithmeticAddition.cpp +++ b/tests/validation/CL/ArithmeticAddition.cpp @@ -244,7 +244,6 @@ BOOST_DATA_TEST_CASE(RunLarge, LargeShapes() * ConvertPolicies() * boost::unit_t BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END() -#ifdef ARM_COMPUTE_ENABLE_FP16 BOOST_AUTO_TEST_SUITE(F16) BOOST_DATA_TEST_CASE(RunSmall, SmallShapes(), shape) { @@ -258,7 +257,6 @@ BOOST_DATA_TEST_CASE(RunSmall, SmallShapes(), shape) validate(CLAccessor(dst), ref_dst); } BOOST_AUTO_TEST_SUITE_END() -#endif /* ARM_COMPUTE_ENABLE_FP16 */ BOOST_AUTO_TEST_SUITE(F32) BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly")) diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp index 6123571de1..a3d7140f99 100644 --- a/tests/validation/CL/ConvolutionLayer.cpp +++ b/tests/validation/CL/ConvolutionLayer.cpp @@ -45,6 +45,7 @@ using namespace arm_compute::test::validation; namespace { +const float tolerance_f16 = 1.f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ const float tolerance_q = 1.0f; /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */ @@ -73,7 +74,7 @@ CLTensor compute_convolution_layer(const TensorShape &input_shape, const TensorS BOOST_TEST(!dst.info()->is_resizable()); // Fill tensors - if(dt == DataType::F32) + if(dt == DataType::F32 || dt == DataType::F16) { std::uniform_real_distribution<> distribution(-1.0f, 1.0f); library->fill(CLAccessor(src), distribution, 0); @@ -134,7 +135,6 @@ BOOST_DATA_TEST_CASE(Configuration, validate(dst.info()->valid_region(), dst_valid_region); } -#ifdef ARM_COMPUTE_ENABLE_FP16 BOOST_AUTO_TEST_SUITE(Float16) BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit")) BOOST_DATA_TEST_CASE(SmallConvolutionLayer, @@ -148,10 +148,9 @@ BOOST_DATA_TEST_CASE(SmallConvolutionLayer, RawTensor ref_dst = Reference::compute_reference_convolution_layer(conv_set.src_shape, conv_set.weights_shape, conv_set.bias_shape, conv_set.dst_shape, dt, conv_set.info, 0); // Validate output - validate(CLAccessor(dst), ref_dst, tolerance_f32); + validate(CLAccessor(dst), ref_dst, tolerance_f16); } BOOST_AUTO_TEST_SUITE_END() -#endif /* ARM_COMPUTE_ENABLE_FP16 */ BOOST_AUTO_TEST_SUITE(Float) BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit")) diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h index 191e32813c..2793c22147 100644 --- a/tests/validation/Helpers.h +++ b/tests/validation/Helpers.h @@ -24,21 +24,17 @@ #ifndef __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ #define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ -#include "ILutAccessor.h" -#include "Types.h" -#include "ValidationUserConfiguration.h" - #include "arm_compute/core/Types.h" +#include "tests/ILutAccessor.h" +#include "tests/Types.h" +#include "tests/validation/ValidationUserConfiguration.h" +#include "tests/validation/half.h" #include #include #include #include -#ifdef ARM_COMPUTE_ENABLE_FP16 -#include -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - namespace arm_compute { namespace test @@ -56,9 +52,7 @@ template inline std::pair get_activation_layer_test_bounds(ActivationLayerInfo::ActivationFunction activation, int fixed_point_position = 1) { bool is_float = std::is_same::value; -#ifdef ARM_COMPUTE_ENABLE_FP16 - is_float = is_float || std::is_same::value; -#endif /* ARM_COMPUTE_ENABLE_FP16 */ + is_float = is_float || std::is_same::value; std::pair bounds; diff --git a/tests/validation/Reference.cpp b/tests/validation/Reference.cpp index 1db3c3f5fb..b94a0e5195 100644 --- a/tests/validation/Reference.cpp +++ b/tests/validation/Reference.cpp @@ -476,15 +476,13 @@ RawTensor Reference::compute_reference_activation_layer(const TensorShape &shape library->fill(ref_src, distribution, 0); break; } -#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { - const std::pair bounds = get_activation_layer_test_bounds(act_info.activation()); + const std::pair bounds = get_activation_layer_test_bounds(act_info.activation()); std::uniform_real_distribution<> distribution(bounds.first, bounds.second); library->fill(ref_src, distribution, 0); break; } -#endif /* ARM_COMPUTE_ENABLE_FP16 */ case DataType::F32: { const std::pair bounds = get_activation_layer_test_bounds(act_info.activation()); @@ -604,9 +602,9 @@ RawTensor Reference::compute_reference_depth_concatenate_layer(const std::vector TensorShape dst_shape = calculate_depth_concatenate_shape(shapes); // Create tensors - for(unsigned int i = 0; i < shapes.size(); ++i) + for(const auto &shape : shapes) { - ref_srcs.push_back(support::cpp14::make_unique(RawTensor(shapes[i], dt, 1, fixed_point_position))); + ref_srcs.push_back(support::cpp14::make_unique(shape, dt, 1, fixed_point_position)); } RawTensor ref_dst(dst_shape, dt, 1, fixed_point_position); diff --git a/tests/validation/TensorFactory.h b/tests/validation/TensorFactory.h index 2f33dd283d..a3bb5f9615 100644 --- a/tests/validation/TensorFactory.h +++ b/tests/validation/TensorFactory.h @@ -24,29 +24,24 @@ #ifndef __ARM_COMPUTE_TEST_TENSOR_FACTORY_H__ #define __ARM_COMPUTE_TEST_TENSOR_FACTORY_H__ -#include "RawTensor.h" -#include "Tensor.h" #include "arm_compute/core/Error.h" +#include "tests/RawTensor.h" +#include "tests/validation/Tensor.h" +#include "tests/validation/half.h" #include "boost_wrapper.h" -#if ARM_COMPUTE_ENABLE_FP16 -#include // needed for float16_t -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - namespace arm_compute { namespace test { namespace validation { -using TensorVariant = boost::variant < Tensor, Tensor, +using TensorVariant = boost::variant, Tensor, Tensor, Tensor, Tensor, Tensor, -#ifdef ARM_COMPUTE_ENABLE_FP16 - Tensor, -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - Tensor>; + Tensor, + Tensor>; /** Helper to create a constant type if the passed reference is constant. */ template @@ -95,12 +90,10 @@ public: using value_type_s32 = typename match_const::type; v = Tensor(shape, dt, fixed_point_position, reinterpret_cast(data)); break; -#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - using value_type_f16 = typename match_const::type; - v = Tensor(shape, dt, fixed_point_position, reinterpret_cast(data)); + using value_type_f16 = typename match_const::type; + v = Tensor(shape, dt, fixed_point_position, reinterpret_cast(data)); break; -#endif /* ARM_COMPUTE_ENABLE_FP16 */ case DataType::F32: using value_type_f32 = typename match_const::type; v = Tensor(shape, dt, fixed_point_position, reinterpret_cast(data)); diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h index 319047816c..359dfe8d03 100644 --- a/tests/validation/TensorOperations.h +++ b/tests/validation/TensorOperations.h @@ -24,18 +24,15 @@ #ifndef __ARM_COMPUTE_TEST_TENSOR_OPERATIONS_H__ #define __ARM_COMPUTE_TEST_TENSOR_OPERATIONS_H__ -#include "FixedPoint.h" -#include "Tensor.h" -#include "Types.h" -#include "Utils.h" -#include "support/ToolchainSupport.h" - -#include "FixedPoint.h" -#include "Types.h" #include "arm_compute/core/FixedPoint.h" #include "arm_compute/core/Types.h" +#include "support/ToolchainSupport.h" +#include "tests/Types.h" +#include "tests/Utils.h" #include "tests/validation/FixedPoint.h" +#include "tests/validation/Tensor.h" #include "tests/validation/ValidationUserConfiguration.h" +#include "tests/validation/half.h" #include #include @@ -44,26 +41,6 @@ #include #include -#if ARM_COMPUTE_ENABLE_FP16 -//Beware! most std templates acting on types don't work with the data type float16_t -namespace std -{ -template <> -class numeric_limits -{ -public: - static float16_t lowest() - { - return -std::numeric_limits::max(); // -inf - }; - static float16_t max() - { - return std::numeric_limits::max(); // +inf - }; -}; -} -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - namespace arm_compute { namespace test @@ -77,11 +54,8 @@ namespace template struct is_floating_point : std::integral_constant < bool, - std::is_same::type>::value || -#ifdef ARM_COMPUTE_ENABLE_FP16 - std::is_same::type>::value || -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - std::is_same::type>::value || std::is_same::type>::value > + std::is_same::type>::value || std::is_same::type>::value + || std::is_same::type>::value || std::is_same::type>::value > { }; @@ -184,7 +158,7 @@ void vector_matrix_multiply(const T *in, const T *weights, const T *bias, T *out { for(int x = 0; x < cols_weights; ++x) { - T acc = 0.0f; + T acc(0); for(int y = 0; y < rows_weights; ++y) { acc += in[y] * weights[x + y * cols_weights]; @@ -456,8 +430,8 @@ void absolute_difference(const Tensor &in1, const Tensor &in2, Tensor(in1[i]) - static_cast(in2[i])); - out[i] = saturate_cast(val); + intermediate_type val(std::abs(static_cast(in1[i]) - static_cast(in2[i]))); + out[i] = saturate_cast(val); } } @@ -708,7 +682,7 @@ void gemm(const Tensor &in1, const Tensor &in2, const Tensor &in3, Tens { for(int c = 0; c < N; ++c) { - T acc = 0.0f; + T acc(0); for(int k = 0; k < K; ++k) { @@ -967,10 +941,10 @@ void activation_layer(const Tensor &in, Tensor &out, ActivationLayerInfo a out[i] = static_cast(1) / (static_cast(1) + std::exp(-x)); break; case ActivationLayerInfo::ActivationFunction::RELU: - out[i] = std::max(0, x); + out[i] = std::max(static_cast(0), x); break; case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU: - out[i] = std::min(a, std::max(0, x)); + out[i] = std::min(a, std::max(static_cast(0), x)); break; case ActivationLayerInfo::ActivationFunction::LEAKY_RELU: out[i] = (x > 0) ? x : a * x; @@ -1519,16 +1493,16 @@ void pooling_layer(const Tensor &in, Tensor &out, PoolingLayerInfo pool_in { for(int w = 0; w < pooled_w; ++w) { - T avg_val = 0; - int wstart = w * pool_stride_x - pad_x; - int hstart = h * pool_stride_y - pad_y; - int wend = std::min(wstart + pool_size, w_in + pad_x); - int hend = std::min(hstart + pool_size, h_in + pad_y); - int pool = (hend - hstart) * (wend - wstart); - wstart = std::max(wstart, 0); - hstart = std::max(hstart, 0); - wend = std::min(wend, w_in); - hend = std::min(hend, h_in); + T avg_val(0); + int wstart = w * pool_stride_x - pad_x; + int hstart = h * pool_stride_y - pad_y; + int wend = std::min(wstart + pool_size, w_in + pad_x); + int hend = std::min(hstart + pool_size, h_in + pad_y); + int pool = (hend - hstart) * (wend - wstart); + wstart = std::max(wstart, 0); + hstart = std::max(hstart, 0); + wend = std::min(wend, w_in); + hend = std::min(hend, h_in); if(is_floating_point::value) { for(int y = hstart; y < hend; ++y) @@ -1652,7 +1626,7 @@ void softmax_layer(const Tensor &in, Tensor &out) } // Regularize - T sum = 0; + T sum(0); for(int c = 0; c < cols; ++c) { const T res = exp(in[r * cols + c] - max); @@ -1661,7 +1635,7 @@ void softmax_layer(const Tensor &in, Tensor &out) } // Normalize - const T norm_val = 1 / sum; + const T norm_val = static_cast(1) / sum; for(int c = 0; c < cols; ++c) { out[r * cols + c] *= norm_val; diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp index 14ee98a96b..a13eeb0b85 100644 --- a/tests/validation/Validation.cpp +++ b/tests/validation/Validation.cpp @@ -23,16 +23,16 @@ */ #include "Validation.h" -#include "IAccessor.h" -#include "RawTensor.h" -#include "TypePrinter.h" -#include "Utils.h" - #include "arm_compute/core/Coordinates.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/FixedPoint.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/runtime/Tensor.h" +#include "tests/IAccessor.h" +#include "tests/RawTensor.h" +#include "tests/TypePrinter.h" +#include "tests/Utils.h" +#include "tests/validation/half.h" #include #include @@ -40,10 +40,6 @@ #include #include -#ifdef ARM_COMPUTE_ENABLE_FP16 -#include // needed for float16_t -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - namespace arm_compute { namespace test @@ -88,10 +84,8 @@ double get_double_data(const void *ptr, DataType data_type) return *reinterpret_cast(ptr); case DataType::S64: return *reinterpret_cast(ptr); -#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return *reinterpret_cast(ptr); -#endif /* ARM_COMPUTE_ENABLE_FP16 */ + return *reinterpret_cast(ptr); case DataType::F32: return *reinterpret_cast(ptr); case DataType::F64: diff --git a/tests/validation/half.h b/tests/validation/half.h new file mode 100644 index 0000000000..fb2235aad9 --- /dev/null +++ b/tests/validation/half.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_TEST_HALF_H__ +#define __ARM_COMPUTE_TEST_HALF_H__ + +#ifdef __ANDROID__ +// Android toolchain is broken and doesn't support all CPP11 math functions. +#define HALF_ENABLE_CPP11_CMATH 0 +#endif /* __ANDROID__ */ + +#include "half/half.hpp" +#endif /* __ARM_COMPUTE_TEST_HALF_H__ */ diff --git a/tests/validation_new/Helpers.h b/tests/validation_new/Helpers.h new file mode 100644 index 0000000000..e25b684c11 --- /dev/null +++ b/tests/validation_new/Helpers.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ +#define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ + +#include "tests/validation/half.h" + +#include + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +template +struct is_floating_point : public std::is_floating_point +{ +}; + +template <> +struct is_floating_point : public std::true_type +{ +}; +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif /* __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ */ diff --git a/tests/validation_new/Validation.cpp b/tests/validation_new/Validation.cpp index 8ab8274d2a..9071663e7c 100644 --- a/tests/validation_new/Validation.cpp +++ b/tests/validation_new/Validation.cpp @@ -27,16 +27,13 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/runtime/Tensor.h" +#include "tests/validation/half.h" #include #include #include #include -#ifdef ARM_COMPUTE_ENABLE_FP16 -#include // needed for float16_t -#endif /* ARM_COMPUTE_ENABLE_FP16 */ - namespace arm_compute { namespace test @@ -81,10 +78,8 @@ double get_double_data(const void *ptr, DataType data_type) return *reinterpret_cast(ptr); case DataType::S64: return *reinterpret_cast(ptr); -#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return *reinterpret_cast(ptr); -#endif /* ARM_COMPUTE_ENABLE_FP16 */ + return *reinterpret_cast(ptr); case DataType::F32: return *reinterpret_cast(ptr); case DataType::F64: diff --git a/tests/validation_new/Validation.h b/tests/validation_new/Validation.h index 5e947caf8d..7db7b00886 100644 --- a/tests/validation_new/Validation.h +++ b/tests/validation_new/Validation.h @@ -85,8 +85,8 @@ void validate(const arm_compute::PaddingSize &padding, const arm_compute::Paddin * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by * other test cases. */ -template -void validate(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value = 0, float tolerance_number = 0.f); +template +void validate(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value = U(0), float tolerance_number = 0.f); /** Validate tensors with valid region. * @@ -98,8 +98,8 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, U toler * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by * other test cases. */ -template -void validate(const IAccessor &tensor, const SimpleTensor &reference, const ValidRegion &valid_region, U tolerance_value = 0, float tolerance_number = 0.f); +template +void validate(const IAccessor &tensor, const SimpleTensor &reference, const ValidRegion &valid_region, U tolerance_value = U(0), float tolerance_number = 0.f); /** Validate tensors against constant value. * -- cgit v1.2.1