diff options
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 0689114074..a18ae32a03 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1334,4 +1334,44 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeStridedSlice) {outputTensorInfo.GetShape()}); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeMean) +{ + class VerifyMeanName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + { + public: + void VisitMeanLayer(const armnn::IConnectableLayer*, const armnn::MeanDescriptor&, const char* name) + { + BOOST_TEST(name == "mean"); + } + }; + + armnn::TensorInfo inputTensorInfo({1, 1, 3, 2}, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({1, 1, 1, 2}, armnn::DataType::Float32); + + armnn::MeanDescriptor descriptor; + descriptor.m_Axis = { 2 }; + descriptor.m_KeepDims = true; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const meanLayer = network->AddMeanLayer(descriptor, "mean"); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(meanLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + meanLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + meanLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyMeanName nameChecker; + deserializedNetwork->Accept(nameChecker); + + CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork, + *network, + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); +} + BOOST_AUTO_TEST_SUITE_END() |