diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2019-03-04 17:44:21 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2019-03-07 09:24:37 +0000 |
commit | ac97c8cda28f81ce76834b8b769967d42b02e2ac (patch) | |
tree | 7bf1b1104a4bfb041ee5d74cabe1e1be2664af70 /src/armnnSerializer/test | |
parent | c192f35edefd0977396db8d381adc7598e3660cc (diff) | |
download | armnn-ac97c8cda28f81ce76834b8b769967d42b02e2ac.tar.gz |
IVGCVSW-2696 Serialize / de-serialize the Mean layer
Change-Id: Iee4bab5a6d6b992cf4bba8697a2918f854c906a3
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Diffstat (limited to 'src/armnnSerializer/test')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 0689114074..a18ae32a03 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1334,4 +1334,44 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeStridedSlice) {outputTensorInfo.GetShape()}); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeMean) +{ + class VerifyMeanName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + { + public: + void VisitMeanLayer(const armnn::IConnectableLayer*, const armnn::MeanDescriptor&, const char* name) + { + BOOST_TEST(name == "mean"); + } + }; + + armnn::TensorInfo inputTensorInfo({1, 1, 3, 2}, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({1, 1, 1, 2}, armnn::DataType::Float32); + + armnn::MeanDescriptor descriptor; + descriptor.m_Axis = { 2 }; + descriptor.m_KeepDims = true; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const meanLayer = network->AddMeanLayer(descriptor, "mean"); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(meanLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + meanLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + meanLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyMeanName nameChecker; + deserializedNetwork->Accept(nameChecker); + + CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork, + *network, + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); +} + BOOST_AUTO_TEST_SUITE_END() |