aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/test')
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp43
1 files changed, 43 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index f4e25998d9..632a80a748 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1834,6 +1834,49 @@ TEST_CASE("SerializePooling2d")
deserializedNetwork->ExecuteStrategy(verifier);
}
+TEST_CASE("SerializePooling3d")
+{
+ const std::string layerName("pooling3d");
+ const armnn::TensorInfo inputInfo({1, 1, 2, 2, 2}, armnn::DataType::Float32);
+ const armnn::TensorInfo outputInfo({1, 1, 1, 1, 1}, armnn::DataType::Float32);
+
+ armnn::Pooling3dDescriptor desc;
+ desc.m_DataLayout = armnn::DataLayout::NDHWC;
+ desc.m_PadFront = 0;
+ desc.m_PadBack = 0;
+ desc.m_PadTop = 0;
+ desc.m_PadBottom = 0;
+ desc.m_PadLeft = 0;
+ desc.m_PadRight = 0;
+ desc.m_PoolType = armnn::PoolingAlgorithm::Average;
+ desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
+ desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
+ desc.m_PoolHeight = 2;
+ desc.m_PoolWidth = 2;
+ desc.m_PoolDepth = 2;
+ desc.m_StrideX = 2;
+ desc.m_StrideY = 2;
+ desc.m_StrideZ = 2;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const pooling3dLayer = network->AddPooling3dLayer(desc, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(pooling3dLayer->GetInputSlot(0));
+ pooling3dLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+ pooling3dLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ LayerVerifierBaseWithDescriptor<armnn::Pooling3dDescriptor> verifier(
+ layerName, {inputInfo}, {outputInfo}, desc);
+ deserializedNetwork->ExecuteStrategy(verifier);
+}
+
TEST_CASE("SerializeQuantize")
{
const std::string layerName("quantize");