diff options
Diffstat (limited to 'src/armnnSerializer/test/ComparisonSerializationTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/ComparisonSerializationTests.cpp | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/src/armnnSerializer/test/ComparisonSerializationTests.cpp b/src/armnnSerializer/test/ComparisonSerializationTests.cpp index 3aee9a7bcb..88778b306a 100644 --- a/src/armnnSerializer/test/ComparisonSerializationTests.cpp +++ b/src/armnnSerializer/test/ComparisonSerializationTests.cpp @@ -12,11 +12,10 @@ #include <armnnDeserializer/IDeserializer.hpp> #include <armnn/utility/IgnoreUnused.hpp> -#include <boost/test/unit_test.hpp> - - -BOOST_AUTO_TEST_SUITE(SerializerTests) +#include <doctest/doctest.h> +TEST_SUITE("SerializerTests") +{ struct ComparisonModel { ComparisonModel(const std::string& layerName, @@ -68,7 +67,7 @@ public: VerifyNameAndConnections(layer, name); const armnn::ComparisonDescriptor& layerDescriptor = static_cast<const armnn::ComparisonDescriptor&>(descriptor); - BOOST_CHECK(layerDescriptor.m_Operation == m_Descriptor.m_Operation); + CHECK(layerDescriptor.m_Operation == m_Descriptor.m_Operation); break; } default: @@ -82,7 +81,7 @@ private: armnn::ComparisonDescriptor m_Descriptor; }; -BOOST_AUTO_TEST_CASE(SerializeEqual) +TEST_CASE("SerializeEqual") { const std::string layerName("equal"); @@ -95,13 +94,13 @@ BOOST_AUTO_TEST_CASE(SerializeEqual) ComparisonModel model(layerName, inputInfo, outputInfo, descriptor); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor); deserializedNetwork->ExecuteStrategy(verifier); } -BOOST_AUTO_TEST_CASE(SerializeGreater) +TEST_CASE("SerializeGreater") { const std::string layerName("greater"); @@ -114,10 +113,10 @@ BOOST_AUTO_TEST_CASE(SerializeGreater) ComparisonModel model(layerName, inputInfo, outputInfo, descriptor); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor); deserializedNetwork->ExecuteStrategy(verifier); } -BOOST_AUTO_TEST_SUITE_END() +} |