diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-04-25 17:14:50 +0100 |
---|---|---|
committer | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-05-04 12:11:42 +0100 |
commit | 6966bfa643305fde25e96bb938cad811cd3b4f31 (patch) | |
tree | 7c6e377e0e0ba74d9e963a94c6cdc8f03ce4a407 /src/armnnSerializer/test/SerializerTests.cpp | |
parent | b2d3ec5b1e938ef34facfdbcff83fc8e845d5f7c (diff) | |
download | armnn-6966bfa643305fde25e96bb938cad811cd3b4f31.tar.gz |
IVGCVSW-6856 Add GATHERNd Serializer and Deserializer
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Ibab3525d53edbdf6a48e43b2bf668fcd2efaba58
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 966dc6c669..a765290de8 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1109,6 +1109,70 @@ TEST_CASE("SerializeGather") deserializedNetwork->ExecuteStrategy(verifier); } +TEST_CASE("SerializeGatherNd") +{ + class GatherNdLayerVerifier : public LayerVerifierBase + { + public: + GatherNdLayerVerifier(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos) + : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override + { + armnn::IgnoreUnused(constants, id); + switch (layer->GetType()) + { + case armnn::LayerType::Input: + case armnn::LayerType::Output: + case armnn::LayerType::Constant: + break; + default: + { + VerifyNameAndConnections(layer, name); + } + } + } + }; + + const std::string layerName("gatherNd"); + armnn::TensorInfo paramsInfo({ 6, 3 }, armnn::DataType::QAsymmU8); + armnn::TensorInfo outputInfo({ 3, 3 }, armnn::DataType::QAsymmU8); + const armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Signed32, 0.0f, 0, true); + + paramsInfo.SetQuantizationScale(1.0f); + paramsInfo.SetQuantizationOffset(0); + outputInfo.SetQuantizationScale(1.0f); + outputInfo.SetQuantizationOffset(0); + + const std::vector<int32_t>& indicesData = {5, 1, 0}; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer *const constantLayer = + network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); + armnn::IConnectableLayer *const gatherNdLayer = network->AddGatherNdLayer(layerName.c_str()); + armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(0)); + constantLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(1)); + gatherNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(paramsInfo); + constantLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo); + gatherNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + GatherNdLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}); + deserializedNetwork->ExecuteStrategy(verifier); +} TEST_CASE("SerializeComparisonGreater") { |