diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/CL/SoftmaxLayer.cpp | 12 | ||||
-rw-r--r-- | tests/validation/NEON/SoftmaxLayer.cpp | 12 | ||||
-rw-r--r-- | tests/validation/Validation.cpp | 4 | ||||
-rw-r--r-- | tests/validation/Validation.h | 16 |
4 files changed, 26 insertions, 18 deletions
diff --git a/tests/validation/CL/SoftmaxLayer.cpp b/tests/validation/CL/SoftmaxLayer.cpp index 6a22eb1bcc..8c143ecd96 100644 --- a/tests/validation/CL/SoftmaxLayer.cpp +++ b/tests/validation/CL/SoftmaxLayer.cpp @@ -48,7 +48,7 @@ RelativeTolerance<half_float::half> tolerance_f16(half_float::half(0.2)); RelativeTolerance<float> tolerance_f32(0.001f); /** Tolerance for fixed point operations */ -constexpr AbsoluteTolerance<int8_t> tolerance_fixed_point(2); +constexpr AbsoluteTolerance<int16_t> tolerance_fixed_point(2); /** CNN data types */ const auto CNNDataTypes = framework::dataset::make("DataType", @@ -145,15 +145,17 @@ TEST_SUITE_END() TEST_SUITE(QS16) // Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14 -FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_fixed_point); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp index 36f1881147..7ac7759c22 100644 --- a/tests/validation/NEON/SoftmaxLayer.cpp +++ b/tests/validation/NEON/SoftmaxLayer.cpp @@ -49,7 +49,7 @@ constexpr AbsoluteTolerance<float> tolerance_f32(0.000001f); constexpr AbsoluteTolerance<float> tolerance_f16(0.0001f); #endif /* ARM_COMPUTE_ENABLE_FP16*/ /** Tolerance for fixed point operations */ -constexpr AbsoluteTolerance<int8_t> tolerance_fixed_point(2); +constexpr AbsoluteTolerance<int16_t> tolerance_fixed_point(2); /** CNN data types */ const auto CNNDataTypes = framework::dataset::make("DataType", @@ -151,15 +151,17 @@ TEST_SUITE_END() TEST_SUITE(QS16) // Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14 -FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output validate(Accessor(_target), _reference, tolerance_fixed_point); } -FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp index 690c4eac9e..1a082111a9 100644 --- a/tests/validation/Validation.cpp +++ b/tests/validation/Validation.cpp @@ -130,7 +130,7 @@ void check_border_element(const IAccessor &tensor, const Coordinates &id, const double target = get_double_data(ptr + channel_offset, tensor.data_type()); const double reference = get_double_data(static_cast<const uint8_t *>(border_value) + channel_offset, tensor.data_type()); - if(!compare<AbsoluteTolerance<double>, double>(target, reference)) + if(!compare<AbsoluteTolerance<double>>(target, reference)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << channel); @@ -192,7 +192,7 @@ void validate(const IAccessor &tensor, const void *reference_value) const double target = get_double_data(ptr + channel_offset, tensor.data_type()); const double reference = get_double_data(reference_value, tensor.data_type()); - if(!compare<AbsoluteTolerance<double>, double>(target, reference)) + if(!compare<AbsoluteTolerance<double>>(target, reference)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << channel); diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index e70c970cc1..6bc42a4ed6 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -226,11 +226,11 @@ struct compare_base T _tolerance{}; }; -template <typename T, typename U> +template <typename T> struct compare; template <typename U> -struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<U>> +struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>> { using compare_base<AbsoluteTolerance<U>>::compare_base; @@ -245,12 +245,16 @@ struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance< return true; } - return static_cast<U>(std::abs(this->_target - this->_reference)) <= static_cast<U>(this->_tolerance); + using comparison_type = typename std::conditional<std::is_integral<U>::value, int64_t, U>::type; + + const comparison_type abs_difference(std::abs(static_cast<comparison_type>(this->_target) - static_cast<comparison_type>(this->_reference))); + + return abs_difference <= static_cast<comparison_type>(this->_tolerance); } }; template <typename U> -struct compare<RelativeTolerance<U>, U> : public compare_base<RelativeTolerance<U>> +struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>> { using compare_base<RelativeTolerance<U>>::compare_base; @@ -325,7 +329,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V const T &target_value = reinterpret_cast<const T *>(tensor(id))[c]; const T &reference_value = reinterpret_cast<const T *>(reference(id))[c]; - if(!compare<U, typename U::value_type>(target_value, reference_value, tolerance_value)) + if(!compare<U>(target_value, reference_value, tolerance_value)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << c); @@ -359,7 +363,7 @@ void validate(T target, T reference, U tolerance) ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference)); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target)); ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance))); - ARM_COMPUTE_EXPECT((compare<U, typename U::value_type>(target, reference, tolerance)), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS); } } // namespace validation } // namespace test |