From 9150bff63a690caa743c471943afe509ebed1044 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 26 May 2021 15:40:53 +0100 Subject: IVGCVSW-4618 'Transition Units Test Suites' * Used doctest in android-nn-driver unit tests. Signed-off-by: Sadik Armagan Change-Id: I9b5d4dfd77d53c7ebee7f8c43628a1d6ff74d1a3 --- test/Lstm.hpp | 48 +++++++++++------------------------------------- 1 file changed, 11 insertions(+), 37 deletions(-) (limited to 'test/Lstm.hpp') diff --git a/test/Lstm.hpp b/test/Lstm.hpp index 2cb3c264..e3844464 100644 --- a/test/Lstm.hpp +++ b/test/Lstm.hpp @@ -9,7 +9,7 @@ #include -#include +#include #include @@ -40,26 +40,6 @@ RequestArgument CreateRequestArgument(const std::vector& value, unsigned int return inputRequestArgument; } -// Returns true if the relative difference between two float values is less than the tolerance value given. -// This is used because the floating point comparison tolerance (set on each BOOST_AUTO_TEST_CASE) does not work! -bool TolerantCompareEqual(float a, float b, float tolerance = 0.00001f) -{ - float rd; - if (a == 0.0f) - { - rd = fabs(b); - } - else if (b == 0.0f) - { - rd = fabs(a); - } - else - { - rd = boost::math::relative_difference(a, b); - } - return rd < tolerance; -} - // Helper function to create an OperandLifeTime::NO_VALUE for testing. // To be used on optional input operands that have no values - these are valid and should be tested. V1_0::OperandLifeTime CreateNoValueLifeTime(const hidl_vec& dimensions) @@ -100,12 +80,6 @@ void ExecuteModel(const armnn_driver::h } // anonymous namespace -#ifndef ARMCOMPUTECL_ENABLED -static const std::array COMPUTE_DEVICES = {{ armnn::Compute::CpuRef }}; -#else -static const std::array COMPUTE_DEVICES = {{ armnn::Compute::CpuRef, armnn::Compute::GpuAcc }}; -#endif - // Add our own tests here since we fail the lstm tests which Google supplies (because of non-const weights) template void LstmTestImpl(const hidl_vec& inputDimensions, @@ -394,18 +368,18 @@ void LstmTestImpl(const hidl_vec& inputDimensions, // check the results for (size_t i = 0; i < outputStateOutValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(outputStateOutValue[i], outputStateOutData[i]), - "outputStateOut[" << i << "]: " << outputStateOutValue[i] << " != " << outputStateOutData[i]); + CHECK_MESSAGE(outputStateOutValue[i] == doctest::Approx( outputStateOutData[i] ), + "outputStateOut[" << i << "]: " << outputStateOutValue[i] << " != " << outputStateOutData[i]); } for (size_t i = 0; i < cellStateOutValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(cellStateOutValue[i], cellStateOutData[i]), - "cellStateOut[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); + CHECK_MESSAGE(cellStateOutValue[i] == doctest::Approx( cellStateOutData[i] ), + "cellStateOutValue[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); } for (size_t i = 0; i < outputValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(outputValue[i], outputData[i]), - "output[" << i << "]: " << outputValue[i] << " != " << outputData[i]); + CHECK_MESSAGE(outputValue[i] == doctest::Approx( outputData[i] ), + "outputValue[" << i << "]: " << outputValue[i] << " != " << outputData[i]); } } @@ -669,13 +643,13 @@ void QuantizedLstmTestImpl(const hidl_vec& inputDimensions, // check the results for (size_t i = 0; i < cellStateOutValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(cellStateOutValue[i], cellStateOutData[i], 1.0f), - "cellStateOut[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); + CHECK_MESSAGE(cellStateOutValue[i] == doctest::Approx( cellStateOutData[i] ), + "cellStateOutValue[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]); } for (size_t i = 0; i < outputValue.size(); ++i) { - BOOST_TEST(TolerantCompareEqual(outputValue[i], outputData[i], 1.0f), - "output[" << i << "]: " << outputValue[i] << " != " << outputData[i]); + CHECK_MESSAGE(outputValue[i] == doctest::Approx( outputData[i] ), + "outputValue[" << i << "]: " << outputValue[i] << " != " << outputData[i]); } } -- cgit v1.2.1