diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2021-05-26 15:40:53 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2021-06-09 14:50:34 +0000 |
commit | 9150bff63a690caa743c471943afe509ebed1044 (patch) | |
tree | f98047d0a3a0e6cf06a4f34e0270a3cc7e3ee8bd /test/Lstm.hpp | |
parent | 07648f41d8b1fe9f532fa9c7e8e96a8e3651e59d (diff) | |
download | android-nn-driver-9150bff63a690caa743c471943afe509ebed1044.tar.gz |
IVGCVSW-4618 'Transition Units Test Suites'
* Used doctest in android-nn-driver unit tests.
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I9b5d4dfd77d53c7ebee7f8c43628a1d6ff74d1a3
Diffstat (limited to 'test/Lstm.hpp')
-rw-r--r-- | test/Lstm.hpp | 48 |
1 files changed, 11 insertions, 37 deletions
diff --git a/test/Lstm.hpp b/test/Lstm.hpp index 2cb3c264..e3844464 100644 --- a/test/Lstm.hpp +++ b/test/Lstm.hpp @@ -9,7 +9,7 @@ #include <armnn/utility/IgnoreUnused.hpp> -#include <boost/math/special_functions/relative_difference.hpp> +#include <doctest/doctest.h> #include <array> @@ -40,26 +40,6 @@ RequestArgument CreateRequestArgument(const std::vector<T>& value, unsigned int return inputRequestArgument; } -// Returns true if the relative difference between two float values is less than the tolerance value given. -// This is used because the floating point comparison tolerance (set on each BOOST_AUTO_TEST_CASE) does not work! -bool TolerantCompareEqual(float a, float b, float tolerance = 0.00001f) -{ - float rd; - if (a == 0.0f) - { - rd = fabs(b); - } - else if (b == 0.0f) - { - rd = fabs(a); - } - else - { - rd = boost::math::relative_difference(a, b); - } - return rd < tolerance; -} - // Helper function to create an OperandLifeTime::NO_VALUE for testing. // To be used on optional input operands that have no values - these are valid and should be tested. V1_0::OperandLifeTime CreateNoValueLifeTime(const hidl_vec<uint32_t>& dimensions) @@ -100,12 +80,6 @@ void ExecuteModel<armnn_driver::hal_1_2::HalPolicy::Model>(const armnn_driver::h } // anonymous namespace -#ifndef ARMCOMPUTECL_ENABLED -static const std::array<armnn::Compute, 1> COMPUTE_DEVICES = {{ armnn::Compute::CpuRef }}; -#else -static const std::array<armnn::Compute, 2> COMPUTE_DEVICES = {{ armnn::Compute::CpuRef, armnn::Compute::GpuAcc }}; -#endif - // Add our own tests here since we fail the lstm tests which Google supplies (because of non-const weights) template <typename HalPolicy> void LstmTestImpl(const hidl_vec<uint32_t>& inputDimensions, @@ -394,18 +368,18 @@ void LstmTestImpl(const hidl_vec<uint32_t>& inputDimensions, // check the results for (size_t i = 0; i < outputStateOutValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(outputStateOutValue[i], outputStateOutData[i]), - "outputStateOut[" << i << "]: " << outputStateOutValue[i] << " != " << outputStateOutData[i]); + CHECK_MESSAGE(outputStateOutValue[i] == doctest::Approx( outputStateOutData[i] ), + "outputStateOut[" << i << "]: " << outputStateOutValue[i] << " != " << outputStateOutData[i]); } for (size_t i = 0; i < cellStateOutValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(cellStateOutValue[i], cellStateOutData[i]), - "cellStateOut[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); + CHECK_MESSAGE(cellStateOutValue[i] == doctest::Approx( cellStateOutData[i] ), + "cellStateOutValue[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); } for (size_t i = 0; i < outputValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(outputValue[i], outputData[i]), - "output[" << i << "]: " << outputValue[i] << " != " << outputData[i]); + CHECK_MESSAGE(outputValue[i] == doctest::Approx( outputData[i] ), + "outputValue[" << i << "]: " << outputValue[i] << " != " << outputData[i]); } } @@ -669,13 +643,13 @@ void QuantizedLstmTestImpl(const hidl_vec<uint32_t>& inputDimensions, // check the results for (size_t i = 0; i < cellStateOutValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(cellStateOutValue[i], cellStateOutData[i], 1.0f), - "cellStateOut[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); + CHECK_MESSAGE(cellStateOutValue[i] == doctest::Approx( cellStateOutData[i] ), + "cellStateOutValue[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); } for (size_t i = 0; i < outputValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(outputValue[i], outputData[i], 1.0f), - "output[" << i << "]: " << outputValue[i] << " != " << outputData[i]); + CHECK_MESSAGE(outputValue[i] == doctest::Approx( outputData[i] ), + "outputValue[" << i << "]: " << outputValue[i] << " != " << outputData[i]); } } |