aboutsummaryrefslogtreecommitdiff
path: root/test/Lstm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/Lstm.hpp')
-rw-r--r--test/Lstm.hpp53
1 files changed, 14 insertions, 39 deletions
diff --git a/test/Lstm.hpp b/test/Lstm.hpp
index 2cb3c264..93f2f32d 100644
--- a/test/Lstm.hpp
+++ b/test/Lstm.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -9,8 +9,6 @@
#include <armnn/utility/IgnoreUnused.hpp>
-#include <boost/math/special_functions/relative_difference.hpp>
-
#include <array>
using ArmnnDriver = armnn_driver::ArmnnDriver;
@@ -40,26 +38,6 @@ RequestArgument CreateRequestArgument(const std::vector<T>& value, unsigned int
return inputRequestArgument;
}
-// Returns true if the relative difference between two float values is less than the tolerance value given.
-// This is used because the floating point comparison tolerance (set on each BOOST_AUTO_TEST_CASE) does not work!
-bool TolerantCompareEqual(float a, float b, float tolerance = 0.00001f)
-{
- float rd;
- if (a == 0.0f)
- {
- rd = fabs(b);
- }
- else if (b == 0.0f)
- {
- rd = fabs(a);
- }
- else
- {
- rd = boost::math::relative_difference(a, b);
- }
- return rd < tolerance;
-}
-
// Helper function to create an OperandLifeTime::NO_VALUE for testing.
// To be used on optional input operands that have no values - these are valid and should be tested.
V1_0::OperandLifeTime CreateNoValueLifeTime(const hidl_vec<uint32_t>& dimensions)
@@ -100,12 +78,6 @@ void ExecuteModel<armnn_driver::hal_1_2::HalPolicy::Model>(const armnn_driver::h
} // anonymous namespace
-#ifndef ARMCOMPUTECL_ENABLED
-static const std::array<armnn::Compute, 1> COMPUTE_DEVICES = {{ armnn::Compute::CpuRef }};
-#else
-static const std::array<armnn::Compute, 2> COMPUTE_DEVICES = {{ armnn::Compute::CpuRef, armnn::Compute::GpuAcc }};
-#endif
-
// Add our own tests here since we fail the lstm tests which Google supplies (because of non-const weights)
template <typename HalPolicy>
void LstmTestImpl(const hidl_vec<uint32_t>& inputDimensions,
@@ -394,18 +366,20 @@ void LstmTestImpl(const hidl_vec<uint32_t>& inputDimensions,
// check the results
for (size_t i = 0; i < outputStateOutValue.size(); ++i)
{
- BOOST_TEST(TolerantCompareEqual(outputStateOutValue[i], outputStateOutData[i]),
- "outputStateOut[" << i << "]: " << outputStateOutValue[i] << " != " << outputStateOutData[i]);
+ DOCTEST_CHECK_MESSAGE(outputStateOutValue[i] == doctest::Approx( outputStateOutData[i] ),
+ "outputStateOut[" << i << "]: " << outputStateOutValue[i] << " != "
+ << outputStateOutData[i]);
}
for (size_t i = 0; i < cellStateOutValue.size(); ++i)
{
- BOOST_TEST(TolerantCompareEqual(cellStateOutValue[i], cellStateOutData[i]),
- "cellStateOut[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]);
+ DOCTEST_CHECK_MESSAGE(cellStateOutValue[i] == doctest::Approx( cellStateOutData[i] ),
+ "cellStateOutValue[" << i << "]: " << cellStateOutValue[i] << " != "
+ << cellStateOutData[i]);
}
for (size_t i = 0; i < outputValue.size(); ++i)
{
- BOOST_TEST(TolerantCompareEqual(outputValue[i], outputData[i]),
- "output[" << i << "]: " << outputValue[i] << " != " << outputData[i]);
+ DOCTEST_CHECK_MESSAGE(outputValue[i] == doctest::Approx( outputData[i] ),
+ "outputValue[" << i << "]: " << outputValue[i] << " != " << outputData[i]);
}
}
@@ -669,13 +643,14 @@ void QuantizedLstmTestImpl(const hidl_vec<uint32_t>& inputDimensions,
// check the results
for (size_t i = 0; i < cellStateOutValue.size(); ++i)
{
- BOOST_TEST(TolerantCompareEqual(cellStateOutValue[i], cellStateOutData[i], 1.0f),
- "cellStateOut[" << i << "]: " << cellStateOutValue[i] << " != " << cellStateOutData[i]);
+ DOCTEST_CHECK_MESSAGE(cellStateOutValue[i] == doctest::Approx( cellStateOutData[i] ),
+ "cellStateOutValue[" << i << "]: " << cellStateOutValue[i] << " != "
+ << cellStateOutData[i]);
}
for (size_t i = 0; i < outputValue.size(); ++i)
{
- BOOST_TEST(TolerantCompareEqual(outputValue[i], outputData[i], 1.0f),
- "output[" << i << "]: " << outputValue[i] << " != " << outputData[i]);
+ DOCTEST_CHECK_MESSAGE(outputValue[i] == doctest::Approx( outputData[i] ),
+ "outputValue[" << i << "]: " << outputValue[i] << " != " << outputData[i]);
}
}