aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp')
-rw-r--r--src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp17
1 files changed, 9 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
index 11003a2e97..035c592738 100644
--- a/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/LstmTestImpl.cpp
@@ -20,6 +20,7 @@
#include <test/TensorHelpers.hpp>
+#include <doctest/doctest.h>
namespace
{
@@ -45,11 +46,11 @@ void LstmUtilsVectorBatchVectorAddTestImpl(
// check shape and compare values
auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
- BOOST_TEST(result.m_Result, result.m_Message.str());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// check if iterator is back at start position
batchVecEncoder->Set(1.0f);
- BOOST_TEST(batchVec[0] == 1.0f);
+ CHECK(batchVec[0] == 1.0f);
}
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
@@ -72,11 +73,11 @@ void LstmUtilsZeroVectorTestImpl(
// check shape and compare values
auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
- BOOST_TEST(result.m_Result, result.m_Message.str());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// check if iterator is back at start position
outputEncoder->Set(1.0f);
- BOOST_TEST(input[0] == 1.0f);
+ CHECK(input[0] == 1.0f);
}
@@ -100,11 +101,11 @@ void LstmUtilsMeanStddevNormalizationTestImpl(
// check shape and compare values
auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
- BOOST_TEST(result.m_Result, result.m_Message.str());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// check if iterator is back at start position
outputEncoder->Set(1.0f);
- BOOST_TEST(input[0] == 1.0f);
+ CHECK(input[0] == 1.0f);
}
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
@@ -129,11 +130,11 @@ void LstmUtilsVectorBatchVectorCwiseProductTestImpl(
// check shape and compare values
auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
- BOOST_TEST(result.m_Result, result.m_Message.str());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// check if iterator is back at start position
batchVecEncoder->Set(1.0f);
- BOOST_TEST(batchVec[0] == 1.0f);
+ CHECK(batchVec[0] == 1.0f);
}
// Lstm Layer tests: