aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TensorHandleStrategyTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/TensorHandleStrategyTest.cpp')
-rw-r--r--src/armnn/test/TensorHandleStrategyTest.cpp57
1 files changed, 29 insertions, 28 deletions
diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp
index 47d0666414..fb26880d0c 100644
--- a/src/armnn/test/TensorHandleStrategyTest.cpp
+++ b/src/armnn/test/TensorHandleStrategyTest.cpp
@@ -2,7 +2,8 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include <boost/test/unit_test.hpp>
+
+#include <doctest/doctest.h>
#include <armnn/LayerVisitorBase.hpp>
@@ -270,29 +271,29 @@ private:
};
-BOOST_AUTO_TEST_SUITE(TensorHandle)
-
-BOOST_AUTO_TEST_CASE(RegisterFactories)
+TEST_SUITE("TensorHandle")
+{
+TEST_CASE("RegisterFactories")
{
TestBackendA backendA;
TestBackendB backendB;
- BOOST_TEST(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1");
- BOOST_TEST(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2");
- BOOST_TEST(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1");
- BOOST_TEST(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1");
+ CHECK(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1");
+ CHECK(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2");
+ CHECK(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1");
+ CHECK(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1");
TensorHandleFactoryRegistry registry;
backendA.RegisterTensorHandleFactories(registry);
backendB.RegisterTensorHandleFactories(registry);
- BOOST_TEST((registry.GetFactory("Non-existing Backend") == nullptr));
- BOOST_TEST((registry.GetFactory("TestHandleFactoryA1") != nullptr));
- BOOST_TEST((registry.GetFactory("TestHandleFactoryA2") != nullptr));
- BOOST_TEST((registry.GetFactory("TestHandleFactoryB1") != nullptr));
+ CHECK((registry.GetFactory("Non-existing Backend") == nullptr));
+ CHECK((registry.GetFactory("TestHandleFactoryA1") != nullptr));
+ CHECK((registry.GetFactory("TestHandleFactoryA2") != nullptr));
+ CHECK((registry.GetFactory("TestHandleFactoryB1") != nullptr));
}
-BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
+TEST_CASE("TensorHandleSelectionStrategy")
{
auto backendA = std::make_unique<TestBackendA>();
auto backendB = std::make_unique<TestBackendB>();
@@ -343,8 +344,8 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
std::vector<std::string> errors;
auto result = SelectTensorHandleStrategy(graph, backends, registry, true, errors);
- BOOST_TEST(result.m_Error == false);
- BOOST_TEST(result.m_Warning == false);
+ CHECK(result.m_Error == false);
+ CHECK(result.m_Warning == false);
OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0);
OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0);
@@ -353,18 +354,18 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0);
// Check that the correct factory was selected
- BOOST_TEST(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
- BOOST_TEST(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
- BOOST_TEST(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
- BOOST_TEST(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1");
- BOOST_TEST(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
+ CHECK(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
+ CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
+ CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
+ CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1");
+ CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
// Check that the correct strategy was selected
- BOOST_TEST((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
- BOOST_TEST((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
- BOOST_TEST((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget));
- BOOST_TEST((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget));
- BOOST_TEST((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
+ CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
+ CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
+ CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget));
+ CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget));
+ CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
graph.AddCompatibilityLayers(backends, registry);
@@ -377,7 +378,7 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
copyCount++;
}
});
- BOOST_TEST(copyCount == 1);
+ CHECK(copyCount == 1);
// Test for import layers
int importCount= 0;
@@ -388,7 +389,7 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
importCount++;
}
});
- BOOST_TEST(importCount == 1);
+ CHECK(importCount == 1);
}
-BOOST_AUTO_TEST_SUITE_END()
+}