aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefPerAxisIteratorTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefPerAxisIteratorTests.cpp')
-rw-r--r--src/backends/reference/test/RefPerAxisIteratorTests.cpp86
1 files changed, 41 insertions, 45 deletions
diff --git a/src/backends/reference/test/RefPerAxisIteratorTests.cpp b/src/backends/reference/test/RefPerAxisIteratorTests.cpp
index 7da4c0fb0f..06d2703b4e 100644
--- a/src/backends/reference/test/RefPerAxisIteratorTests.cpp
+++ b/src/backends/reference/test/RefPerAxisIteratorTests.cpp
@@ -4,37 +4,35 @@
//
#include <reference/workloads/Decoders.hpp>
-#include <armnn/utility/NumericCast.hpp>
#include <fmt/format.h>
-#include <boost/test/unit_test.hpp>
-#include <chrono>
+#include <doctest/doctest.h>
+#include <chrono>
template<typename T>
void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
{
- BOOST_TEST(vec1.size() == vec2.size());
+ CHECK(vec1.size() == vec2.size());
bool mismatch = false;
for (uint i = 0; i < vec1.size(); ++i)
{
if (vec1[i] != vec2[i])
{
- /*std::stringstream ss;
- ss << "Vector value mismatch: index=" << i << " " << vec1[i] << "!=" << vec2[i];*/
- BOOST_TEST_MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
- i,
- vec1[i],
- vec2[i]));
+ MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
+ i,
+ vec1[i],
+ vec2[i]));
+
mismatch = true;
}
}
if (mismatch)
{
- BOOST_FAIL("Error in CompareVector. Vectors don't match.");
+ FAIL("Error in CompareVector. Vectors don't match.");
}
}
@@ -79,10 +77,10 @@ public:
unsigned int m_NumElements;
};
-BOOST_AUTO_TEST_SUITE(RefPerAxisIterator)
-
+TEST_SUITE("RefPerAxisIterator")
+{
// Test Loop (Equivalent to DecodeTensor) and Axis = 0
-BOOST_AUTO_TEST_CASE(PerAxisIteratorTest1)
+TEST_CASE("PerAxisIteratorTest1")
{
std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
@@ -95,17 +93,17 @@ BOOST_AUTO_TEST_CASE(PerAxisIteratorTest1)
// Set iterator to index and check if the axis index is correct
iterator[5];
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.GetAxisIndex() == 1u);
iterator[1];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
iterator[10];
- BOOST_TEST(iterator.GetAxisIndex() == 2u);
+ CHECK(iterator.GetAxisIndex() == 2u);
}
// Test Axis = 1
-BOOST_AUTO_TEST_CASE(PerAxisIteratorTest2)
+TEST_CASE("PerAxisIteratorTest2")
{
std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
@@ -118,17 +116,17 @@ BOOST_AUTO_TEST_CASE(PerAxisIteratorTest2)
// Set iterator to index and check if the axis index is correct
iterator[5];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
iterator[1];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
iterator[10];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
}
// Test Axis = 2
-BOOST_AUTO_TEST_CASE(PerAxisIteratorTest3)
+TEST_CASE("PerAxisIteratorTest3")
{
std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
@@ -141,17 +139,17 @@ BOOST_AUTO_TEST_CASE(PerAxisIteratorTest3)
// Set iterator to index and check if the axis index is correct
iterator[5];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
iterator[1];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
iterator[10];
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.GetAxisIndex() == 1u);
}
// Test Axis = 3
-BOOST_AUTO_TEST_CASE(PerAxisIteratorTest4)
+TEST_CASE("PerAxisIteratorTest4")
{
std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
@@ -164,18 +162,17 @@ BOOST_AUTO_TEST_CASE(PerAxisIteratorTest4)
// Set iterator to index and check if the axis index is correct
iterator[5];
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.GetAxisIndex() == 1u);
iterator[1];
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.GetAxisIndex() == 1u);
iterator[10];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
}
-
// Test Axis = 1. Different tensor shape
-BOOST_AUTO_TEST_CASE(PerAxisIteratorTest5)
+TEST_CASE("PerAxisIteratorTest5")
{
using namespace armnn;
std::vector<int8_t> input =
@@ -201,17 +198,17 @@ BOOST_AUTO_TEST_CASE(PerAxisIteratorTest5)
// Set iterator to index and check if the axis index is correct
iterator[5];
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.GetAxisIndex() == 1u);
iterator[1];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
iterator[10];
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.GetAxisIndex() == 0u);
}
// Test the increment and decrement operator
-BOOST_AUTO_TEST_CASE(PerAxisIteratorTest7)
+TEST_CASE("PerAxisIteratorTest7")
{
using namespace armnn;
std::vector<int8_t> input =
@@ -232,21 +229,20 @@ BOOST_AUTO_TEST_CASE(PerAxisIteratorTest7)
auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
iterator += 3;
- BOOST_TEST(iterator.Get(), expOutput[3]);
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.Get() == expOutput[3]);
+ CHECK(iterator.GetAxisIndex() == 1u);
iterator += 3;
- BOOST_TEST(iterator.Get(), expOutput[6]);
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.Get() == expOutput[6]);
+ CHECK(iterator.GetAxisIndex() == 1u);
iterator -= 2;
- BOOST_TEST(iterator.Get(), expOutput[4]);
- BOOST_TEST(iterator.GetAxisIndex() == 0u);
+ CHECK(iterator.Get() == expOutput[4]);
+ CHECK(iterator.GetAxisIndex() == 0u);
iterator -= 1;
- BOOST_TEST(iterator.Get(), expOutput[3]);
- BOOST_TEST(iterator.GetAxisIndex() == 1u);
+ CHECK(iterator.Get() == expOutput[3]);
+ CHECK(iterator.GetAxisIndex() == 1u);
}
-
-BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file
+} \ No newline at end of file