aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test/LstmSerializationTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/test/LstmSerializationTests.cpp')
-rw-r--r--src/armnnSerializer/test/LstmSerializationTests.cpp41
1 files changed, 20 insertions, 21 deletions
diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp
index 4705c0bd28..c2bc8737b4 100644
--- a/src/armnnSerializer/test/LstmSerializationTests.cpp
+++ b/src/armnnSerializer/test/LstmSerializationTests.cpp
@@ -14,13 +14,12 @@
#include <armnn/LstmParams.hpp>
#include <armnn/QuantizedLstmParams.hpp>
-#include <boost/test/unit_test.hpp>
-
+#include <doctest/doctest.h>
#include <fmt/format.h>
-BOOST_AUTO_TEST_SUITE(SerializerTests)
-
+TEST_SUITE("SerializerTests")
+{
template<typename Descriptor>
armnn::LstmInputParams ConstantVector2LstmInputParams(const std::vector<armnn::ConstTensor>& constants,
Descriptor& descriptor)
@@ -175,7 +174,7 @@ private:
armnn::LstmInputParams m_InputParams;
};
-BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection)
+TEST_CASE("SerializeDeserializeLstmCifgPeepholeNoProjection")
{
armnn::LstmDescriptor descriptor;
descriptor.m_ActivationFunc = 4;
@@ -278,7 +277,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection)
lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
VerifyLstmLayer<armnn::LstmDescriptor> checker(
layerName,
@@ -289,7 +288,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection)
deserializedNetwork->ExecuteStrategy(checker);
}
-BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection)
+TEST_CASE("SerializeDeserializeLstmNoCifgWithPeepholeAndProjection")
{
armnn::LstmDescriptor descriptor;
descriptor.m_ActivationFunc = 4;
@@ -424,7 +423,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection)
lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
VerifyLstmLayer<armnn::LstmDescriptor> checker(
layerName,
@@ -435,7 +434,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection)
deserializedNetwork->ExecuteStrategy(checker);
}
-BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWithLayerNorm)
+TEST_CASE("SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWithLayerNorm")
{
armnn::LstmDescriptor descriptor;
descriptor.m_ActivationFunc = 4;
@@ -589,7 +588,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWit
lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
VerifyLstmLayer<armnn::LstmDescriptor> checker(
layerName,
@@ -600,7 +599,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeWithProjectionWit
deserializedNetwork->ExecuteStrategy(checker);
}
-BOOST_AUTO_TEST_CASE(EnsureLstmLayersBackwardCompatibility)
+TEST_CASE("EnsureLstmLayersBackwardCompatibility")
{
// The hex data below is a flat buffer containing a lstm layer with no Cifg, with peephole and projection
// enabled. That data was obtained before additional layer normalization parameters where added to the
@@ -1220,7 +1219,7 @@ BOOST_AUTO_TEST_CASE(EnsureLstmLayersBackwardCompatibility)
DeserializeNetwork(std::string(lstmNoCifgWithPeepholeAndProjectionModel.begin(),
lstmNoCifgWithPeepholeAndProjectionModel.end()));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
// generating the same model parameters which where used to serialize the model (Layer norm is not specified)
armnn::LstmDescriptor descriptor;
@@ -1428,7 +1427,7 @@ private:
armnn::QuantizedLstmInputParams m_InputParams;
};
-BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm)
+TEST_CASE("SerializeDeserializeQuantizedLstm")
{
const uint32_t batchSize = 1;
const uint32_t inputSize = 2;
@@ -1600,7 +1599,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm)
quantizedLstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
VerifyQuantizedLstmLayer checker(layerName,
{inputTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
@@ -1610,7 +1609,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm)
deserializedNetwork->ExecuteStrategy(checker);
}
-BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic)
+TEST_CASE("SerializeDeserializeQLstmBasic")
{
armnn::QLstmDescriptor descriptor;
@@ -1755,7 +1754,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic)
qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
VerifyLstmLayer<armnn::QLstmDescriptor> checker(
layerName,
@@ -1767,7 +1766,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic)
deserializedNetwork->ExecuteStrategy(checker);
}
-BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm)
+TEST_CASE("SerializeDeserializeQLstmCifgLayerNorm")
{
armnn::QLstmDescriptor descriptor;
@@ -1944,7 +1943,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm)
qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
VerifyLstmLayer<armnn::QLstmDescriptor> checker(layerName,
{inputInfo, cellStateInfo, outputStateInfo},
@@ -1955,7 +1954,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm)
deserializedNetwork->ExecuteStrategy(checker);
}
-BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced)
+TEST_CASE("SerializeDeserializeQLstmAdvanced")
{
armnn::QLstmDescriptor descriptor;
@@ -2185,7 +2184,7 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced)
qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
+ CHECK(deserializedNetwork);
VerifyLstmLayer<armnn::QLstmDescriptor> checker(layerName,
{inputInfo, cellStateInfo, outputStateInfo},
@@ -2196,4 +2195,4 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced)
deserializedNetwork->ExecuteStrategy(checker);
}
-BOOST_AUTO_TEST_SUITE_END()
+}