aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test/SerializerTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp64
1 files changed, 64 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 966dc6c669..a765290de8 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1109,6 +1109,70 @@ TEST_CASE("SerializeGather")
deserializedNetwork->ExecuteStrategy(verifier);
}
+TEST_CASE("SerializeGatherNd")
+{
+ class GatherNdLayerVerifier : public LayerVerifierBase
+ {
+ public:
+ GatherNdLayerVerifier(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos)
+ : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
+ {
+ armnn::IgnoreUnused(constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input:
+ case armnn::LayerType::Output:
+ case armnn::LayerType::Constant:
+ break;
+ default:
+ {
+ VerifyNameAndConnections(layer, name);
+ }
+ }
+ }
+ };
+
+ const std::string layerName("gatherNd");
+ armnn::TensorInfo paramsInfo({ 6, 3 }, armnn::DataType::QAsymmU8);
+ armnn::TensorInfo outputInfo({ 3, 3 }, armnn::DataType::QAsymmU8);
+ const armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Signed32, 0.0f, 0, true);
+
+ paramsInfo.SetQuantizationScale(1.0f);
+ paramsInfo.SetQuantizationOffset(0);
+ outputInfo.SetQuantizationScale(1.0f);
+ outputInfo.SetQuantizationOffset(0);
+
+ const std::vector<int32_t>& indicesData = {5, 1, 0};
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer *const constantLayer =
+ network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
+ armnn::IConnectableLayer *const gatherNdLayer = network->AddGatherNdLayer(layerName.c_str());
+ armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(0));
+ constantLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(1));
+ gatherNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(paramsInfo);
+ constantLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo);
+ gatherNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ GatherNdLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo});
+ deserializedNetwork->ExecuteStrategy(verifier);
+}
TEST_CASE("SerializeComparisonGreater")
{