diff options
Diffstat (limited to 'src/armnn/test/TensorHandleStrategyTest.cpp')
-rw-r--r-- | src/armnn/test/TensorHandleStrategyTest.cpp | 57 |
1 files changed, 29 insertions, 28 deletions
diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp index 47d0666414..fb26880d0c 100644 --- a/src/armnn/test/TensorHandleStrategyTest.cpp +++ b/src/armnn/test/TensorHandleStrategyTest.cpp @@ -2,7 +2,8 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // -#include <boost/test/unit_test.hpp> + +#include <doctest/doctest.h> #include <armnn/LayerVisitorBase.hpp> @@ -270,29 +271,29 @@ private: }; -BOOST_AUTO_TEST_SUITE(TensorHandle) - -BOOST_AUTO_TEST_CASE(RegisterFactories) +TEST_SUITE("TensorHandle") +{ +TEST_CASE("RegisterFactories") { TestBackendA backendA; TestBackendB backendB; - BOOST_TEST(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1"); - BOOST_TEST(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2"); - BOOST_TEST(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1"); - BOOST_TEST(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1"); + CHECK(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1"); + CHECK(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2"); + CHECK(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1"); + CHECK(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1"); TensorHandleFactoryRegistry registry; backendA.RegisterTensorHandleFactories(registry); backendB.RegisterTensorHandleFactories(registry); - BOOST_TEST((registry.GetFactory("Non-existing Backend") == nullptr)); - BOOST_TEST((registry.GetFactory("TestHandleFactoryA1") != nullptr)); - BOOST_TEST((registry.GetFactory("TestHandleFactoryA2") != nullptr)); - BOOST_TEST((registry.GetFactory("TestHandleFactoryB1") != nullptr)); + CHECK((registry.GetFactory("Non-existing Backend") == nullptr)); + CHECK((registry.GetFactory("TestHandleFactoryA1") != nullptr)); + CHECK((registry.GetFactory("TestHandleFactoryA2") != nullptr)); + CHECK((registry.GetFactory("TestHandleFactoryB1") != nullptr)); } -BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) +TEST_CASE("TensorHandleSelectionStrategy") { auto backendA = std::make_unique<TestBackendA>(); auto backendB = std::make_unique<TestBackendB>(); @@ -343,8 +344,8 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) std::vector<std::string> errors; auto result = SelectTensorHandleStrategy(graph, backends, registry, true, errors); - BOOST_TEST(result.m_Error == false); - BOOST_TEST(result.m_Warning == false); + CHECK(result.m_Error == false); + CHECK(result.m_Warning == false); OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0); OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0); @@ -353,18 +354,18 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0); // Check that the correct factory was selected - BOOST_TEST(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); - BOOST_TEST(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); - BOOST_TEST(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); - BOOST_TEST(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1"); - BOOST_TEST(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); + CHECK(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); + CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); + CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); + CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1"); + CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); // Check that the correct strategy was selected - BOOST_TEST((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); - BOOST_TEST((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); - BOOST_TEST((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget)); - BOOST_TEST((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget)); - BOOST_TEST((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); + CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); + CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); + CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget)); + CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget)); + CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); graph.AddCompatibilityLayers(backends, registry); @@ -377,7 +378,7 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) copyCount++; } }); - BOOST_TEST(copyCount == 1); + CHECK(copyCount == 1); // Test for import layers int importCount= 0; @@ -388,7 +389,7 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) importCount++; } }); - BOOST_TEST(importCount == 1); + CHECK(importCount == 1); } -BOOST_AUTO_TEST_SUITE_END() +} |