diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-09-09 17:16:24 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2019-09-11 08:57:05 +0000 |
commit | 0cfcf235c4bcd2ae570eea8bc2677f471281b8e6 (patch) | |
tree | dbeba0ddeb240f91099f6d436757cd365dc8ab2b /src/armnnSerializer/test/SerializerTests.cpp | |
parent | 4cd29a046c3d46917d84d12feb668969af23a39e (diff) | |
download | armnn-0cfcf235c4bcd2ae570eea8bc2677f471281b8e6.tar.gz |
IVGCVSW-3724 Adding serialization support for ArgMinMax
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I21210c843c3b8800ccc68d4f3095259d0a233bd1
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index b5ef8c6b4e..bbd5402c7c 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -288,6 +288,62 @@ BOOST_AUTO_TEST_CASE(SerializeAddition) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeArgMinMax) +{ + class ArgMinMaxLayerVerifier : public LayerVerifierBase + { + public: + ArgMinMaxLayerVerifier(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos, + const armnn::ArgMinMaxDescriptor& descriptor) + : LayerVerifierBase(layerName, inputInfos, outputInfos) + , m_Descriptor(descriptor) {} + + void VisitArgMinMaxLayer(const armnn::IConnectableLayer* layer, + const armnn::ArgMinMaxDescriptor& descriptor, + const char* name) override + { + VerifyNameAndConnections(layer, name); + VerifyDescriptor(descriptor); + } + + private: + void VerifyDescriptor(const armnn::ArgMinMaxDescriptor& descriptor) + { + BOOST_CHECK(descriptor.m_Function == m_Descriptor.m_Function); + BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis); + } + + armnn::ArgMinMaxDescriptor m_Descriptor; + }; + + const std::string layerName("argminmax"); + const armnn::TensorInfo inputInfo({1, 2, 3}, armnn::DataType::Float32); + const armnn::TensorInfo outputInfo({1, 3}, armnn::DataType::Signed32); + + armnn::ArgMinMaxDescriptor descriptor; + descriptor.m_Function = armnn::ArgMinMaxFunction::Max; + descriptor.m_Axis = 1; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const argMinMaxLayer = network->AddArgMinMaxLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(argMinMaxLayer->GetInputSlot(0)); + argMinMaxLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + argMinMaxLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + ArgMinMaxLayerVerifier verifier(layerName, {inputInfo}, {outputInfo}, descriptor); + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeBatchNormalization) { class BatchNormalizationLayerVerifier : public LayerVerifierBase |