diff options
Diffstat (limited to 'src/armnnSerializer/test/LstmSerializationTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/LstmSerializationTests.cpp | 41 |
1 files changed, 20 insertions, 21 deletions
diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp index 4705c0bd28..c2bc8737b4 100644 --- a/src/armnnSerializer/test/LstmSerializationTests.cpp +++ b/src/armnnSerializer/test/LstmSerializationTests.cpp @@ -14,13 +14,12 @@ #include <armnn/LstmParams.hpp> #include <armnn/QuantizedLstmParams.hpp> -#include <boost/test/unit_test.hpp> - +#include <doctest/doctest.h> #include <fmt/format.h> -BOOST_AUTO_TEST_SUITE(SerializerTests) - +TEST_SUITE("SerializerTests") +{ template<typename Descriptor> armnn::LstmInputParams ConstantVector2LstmInputParams(const std::vector<armnn::ConstTensor>& constants, Descriptor& descriptor) @@ -175,7 +174,7 @@ private: armnn::LstmInputParams m_InputParams; }; -BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection) +TEST_CASE("SerializeDeserializeLstmCifgPeepholeNoProjection") { armnn::LstmDescriptor descriptor; descriptor.m_ActivationFunc = 4; @@ -278,7 +277,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection) lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); VerifyLstmLayer<armnn::LstmDescriptor> checker( layerName, @@ -289,7 +288,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection) deserializedNetwork->ExecuteStrategy(checker); } -BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection) +TEST_CASE("SerializeDeserializeLstmNoCifgWithPeepholeAndProjection") { armnn::LstmDescriptor descriptor; descriptor.m_ActivationFunc = 4; @@ -424,7 +423,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection) lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); VerifyLstmLayer<armnn::LstmDescriptor> checker( layerName, @@ -435,7 +434,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection) deserializedNetwork->ExecuteStrategy(checker); } -BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWithLayerNorm) +TEST_CASE("SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWithLayerNorm") { armnn::LstmDescriptor descriptor; descriptor.m_ActivationFunc = 4; @@ -589,7 +588,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWit lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); VerifyLstmLayer<armnn::LstmDescriptor> checker( layerName, @@ -600,7 +599,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWit deserializedNetwork->ExecuteStrategy(checker); } -BOOST_AUTO_TEST_CASE(EnsureLstmLayersBackwardCompatibility) +TEST_CASE("EnsureLstmLayersBackwardCompatibility") { // The hex data below is a flat buffer containing a lstm layer with no Cifg, with peephole and projection // enabled. That data was obtained before additional layer normalization parameters where added to the @@ -1220,7 +1219,7 @@ BOOST_AUTO_TEST_CASE(EnsureLstmLayersBackwardCompatibility) DeserializeNetwork(std::string(lstmNoCifgWithPeepholeAndProjectionModel.begin(), lstmNoCifgWithPeepholeAndProjectionModel.end())); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); // generating the same model parameters which where used to serialize the model (Layer norm is not specified) armnn::LstmDescriptor descriptor; @@ -1428,7 +1427,7 @@ private: armnn::QuantizedLstmInputParams m_InputParams; }; -BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm) +TEST_CASE("SerializeDeserializeQuantizedLstm") { const uint32_t batchSize = 1; const uint32_t inputSize = 2; @@ -1600,7 +1599,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm) quantizedLstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); VerifyQuantizedLstmLayer checker(layerName, {inputTensorInfo, cellStateTensorInfo, outputStateTensorInfo}, @@ -1610,7 +1609,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm) deserializedNetwork->ExecuteStrategy(checker); } -BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic) +TEST_CASE("SerializeDeserializeQLstmBasic") { armnn::QLstmDescriptor descriptor; @@ -1755,7 +1754,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic) qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); VerifyLstmLayer<armnn::QLstmDescriptor> checker( layerName, @@ -1767,7 +1766,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic) deserializedNetwork->ExecuteStrategy(checker); } -BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm) +TEST_CASE("SerializeDeserializeQLstmCifgLayerNorm") { armnn::QLstmDescriptor descriptor; @@ -1944,7 +1943,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm) qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); VerifyLstmLayer<armnn::QLstmDescriptor> checker(layerName, {inputInfo, cellStateInfo, outputStateInfo}, @@ -1955,7 +1954,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm) deserializedNetwork->ExecuteStrategy(checker); } -BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced) +TEST_CASE("SerializeDeserializeQLstmAdvanced") { armnn::QLstmDescriptor descriptor; @@ -2185,7 +2184,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced) qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); - BOOST_CHECK(deserializedNetwork); + CHECK(deserializedNetwork); VerifyLstmLayer<armnn::QLstmDescriptor> checker(layerName, {inputInfo, cellStateInfo, outputStateInfo}, @@ -2196,4 +2195,4 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced) deserializedNetwork->ExecuteStrategy(checker); } -BOOST_AUTO_TEST_SUITE_END() +} |