aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2022-04-25 17:14:50 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2022-05-04 12:11:42 +0100
commit6966bfa643305fde25e96bb938cad811cd3b4f31 (patch)
tree7c6e377e0e0ba74d9e963a94c6cdc8f03ce4a407
parentb2d3ec5b1e938ef34facfdbcff83fc8e845d5f7c (diff)
downloadarmnn-6966bfa643305fde25e96bb938cad811cd3b4f31.tar.gz
IVGCVSW-6856 Add GATHERNd Serializer and Deserializer
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ibab3525d53edbdf6a48e43b2bf668fcd2efaba58
-rw-r--r--CMakeLists.txt1
-rw-r--r--docs/05_02_deserializer_serializer.dox3
-rw-r--r--src/armnnDeserializer/Deserializer.cpp23
-rw-r--r--src/armnnDeserializer/Deserializer.hpp1
-rw-r--r--src/armnnDeserializer/test/DeserializeGatherNd.cpp146
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs6
-rw-r--r--src/armnnSerializer/ArmnnSchema_generated.h89
-rw-r--r--src/armnnSerializer/Serializer.cpp16
-rw-r--r--src/armnnSerializer/Serializer.hpp3
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp64
10 files changed, 341 insertions, 11 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d3d1facc9b..267ad37ce8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -743,6 +743,7 @@ if(BUILD_UNIT_TESTS)
src/armnnDeserializer/test/DeserializeFloor.cpp
src/armnnDeserializer/test/DeserializeFullyConnected.cpp
src/armnnDeserializer/test/DeserializeGather.cpp
+ src/armnnDeserializer/test/DeserializeGatherNd.cpp
src/armnnDeserializer/test/DeserializeInstanceNormalization.cpp
src/armnnDeserializer/test/DeserializeL2Normalization.cpp
src/armnnDeserializer/test/DeserializeLogSoftmax.cpp
diff --git a/docs/05_02_deserializer_serializer.dox b/docs/05_02_deserializer_serializer.dox
index 84324d89d5..6cfaf29968 100644
--- a/docs/05_02_deserializer_serializer.dox
+++ b/docs/05_02_deserializer_serializer.dox
@@ -41,6 +41,7 @@ The Arm NN SDK Serializer currently supports the following layers:
- Floor
- FullyConnected
- Gather
+- GatherNd
- Input
- InstanceNormalization
- L2Normalization
@@ -131,6 +132,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
- Floor
- FullyConnected
- Gather
+- GatherNd
- Input
- InstanceNormalization
- L2Normalization
@@ -147,6 +149,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
- Pad
- Permute
- Pooling2d
+- Pooling3d
- Prelu
- Quantize
- QLstm
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 11d3542405..75c60cc906 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -234,6 +234,7 @@ m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupporte
m_ParserFunctions[Layer_FillLayer] = &DeserializerImpl::ParseFill;
m_ParserFunctions[Layer_FloorLayer] = &DeserializerImpl::ParseFloor;
m_ParserFunctions[Layer_GatherLayer] = &DeserializerImpl::ParseGather;
+ m_ParserFunctions[Layer_GatherNdLayer] = &DeserializerImpl::ParseGatherNd;
m_ParserFunctions[Layer_GreaterLayer] = &DeserializerImpl::ParseGreater;
m_ParserFunctions[Layer_InstanceNormalizationLayer] = &DeserializerImpl::ParseInstanceNormalization;
m_ParserFunctions[Layer_L2NormalizationLayer] = &DeserializerImpl::ParseL2Normalization;
@@ -331,6 +332,8 @@ LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& gr
return graphPtr->layers()->Get(layerIndex)->layer_as_FloorLayer()->base();
case Layer::Layer_GatherLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_GatherLayer()->base();
+ case Layer::Layer_GatherNdLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_GatherNdLayer()->base();
case Layer::Layer_GreaterLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_GreaterLayer()->base();
case Layer::Layer_InputLayer:
@@ -2933,6 +2936,26 @@ void IDeserializer::DeserializerImpl::ParseGather(GraphPtr graph, unsigned int l
RegisterOutputSlots(graph, layerIndex, layer);
}
+void IDeserializer::DeserializerImpl::ParseGatherNd(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+
+ TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = GetLayerName(graph, layerIndex);
+ IConnectableLayer* layer = m_Network->AddGatherNdLayer(layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
void IDeserializer::DeserializerImpl::ParseMean(GraphPtr graph, unsigned int layerIndex)
{
CHECK_LAYERS(graph, 0, layerIndex);
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 8de492ed5f..277c09ae48 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -108,6 +108,7 @@ private:
void ParseFloor(GraphPtr graph, unsigned int layerIndex);
void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex);
void ParseGather(GraphPtr graph, unsigned int layerIndex);
+ void ParseGatherNd(GraphPtr graph, unsigned int layerIndex);
void ParseGreater(GraphPtr graph, unsigned int layerIndex);
void ParseInstanceNormalization(GraphPtr graph, unsigned int layerIndex);
void ParseL2Normalization(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/test/DeserializeGatherNd.cpp b/src/armnnDeserializer/test/DeserializeGatherNd.cpp
new file mode 100644
index 0000000000..684a42ca07
--- /dev/null
+++ b/src/armnnDeserializer/test/DeserializeGatherNd.cpp
@@ -0,0 +1,146 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserFlatbuffersSerializeFixture.hpp"
+#include <armnnDeserializer/IDeserializer.hpp>
+
+#include <string>
+
+TEST_SUITE("Deserializer_GatherNd")
+{
+struct GatherNdFixture : public ParserFlatbuffersSerializeFixture
+{
+ explicit GatherNdFixture(const std::string& paramsShape,
+ const std::string& indicesShape,
+ const std::string& outputShape,
+ const std::string& indicesData,
+ const std::string dataType,
+ const std::string constDataType)
+ {
+ m_JsonString = R"(
+ {
+ inputIds: [0],
+ outputIds: [3],
+ layers: [
+ {
+ layer_type: "InputLayer",
+ layer: {
+ base: {
+ layerBindingId: 0,
+ base: {
+ index: 0,
+ layerName: "InputLayer",
+ layerType: "Input",
+ inputSlots: [{
+ index: 0,
+ connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+ }],
+ outputSlots: [ {
+ index: 0,
+ tensorInfo: {
+ dimensions: )" + paramsShape + R"(,
+ dataType: )" + dataType + R"(
+ }}]
+ }
+ }}},
+ {
+ layer_type: "ConstantLayer",
+ layer: {
+ base: {
+ index:1,
+ layerName: "ConstantLayer",
+ layerType: "Constant",
+ outputSlots: [ {
+ index: 0,
+ tensorInfo: {
+ dimensions: )" + indicesShape + R"(,
+ dataType: "Signed32",
+ },
+ }],
+ },
+ input: {
+ info: {
+ dimensions: )" + indicesShape + R"(,
+ dataType: )" + dataType + R"(
+ },
+ data_type: )" + constDataType + R"(,
+ data: {
+ data: )" + indicesData + R"(,
+ } }
+ },},
+ {
+ layer_type: "GatherNdLayer",
+ layer: {
+ base: {
+ index: 2,
+ layerName: "GatherNdLayer",
+ layerType: "GatherNd",
+ inputSlots: [
+ {
+ index: 0,
+ connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+ },
+ {
+ index: 1,
+ connection: {sourceLayerIndex:1, outputSlotIndex:0 }
+ }],
+ outputSlots: [ {
+ index: 0,
+ tensorInfo: {
+ dimensions: )" + outputShape + R"(,
+ dataType: )" + dataType + R"(
+
+ }}]},
+ }},
+ {
+ layer_type: "OutputLayer",
+ layer: {
+ base:{
+ layerBindingId: 0,
+ base: {
+ index: 3,
+ layerName: "OutputLayer",
+ layerType: "Output",
+ inputSlots: [{
+ index: 0,
+ connection: {sourceLayerIndex:2, outputSlotIndex:0 },
+ }],
+ outputSlots: [ {
+ index: 0,
+ tensorInfo: {
+ dimensions: )" + outputShape + R"(,
+ dataType: )" + dataType + R"(
+ },
+ }],
+ }}},
+ }]
+ } )";
+
+ Setup();
+ }
+};
+
+struct SimpleGatherNdFixtureFloat32 : GatherNdFixture
+{
+ SimpleGatherNdFixtureFloat32() : GatherNdFixture("[ 6, 3 ]", "[ 3, 1 ]", "[ 3, 3 ]",
+ "[ 5, 1, 0 ]", "Float32", "IntData") {}
+};
+
+TEST_CASE_FIXTURE(SimpleGatherNdFixtureFloat32, "GatherNdFloat32")
+{
+ RunTest<4, armnn::DataType::Float32>(0,
+ {{"InputLayer", { 1, 2, 3,
+ 4, 5, 6,
+ 7, 8, 9,
+ 10, 11, 12,
+ 13, 14, 15,
+ 16, 17, 18 }}},
+ {{"OutputLayer", { 16, 17, 18,
+ 4, 5, 6,
+ 1, 2, 3}}});
+}
+
+}
+
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index c8ffce48bc..f301fce818 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -181,6 +181,7 @@ enum LayerType : uint {
ChannelShuffle = 64,
Convolution3d = 65,
Pooling3d = 66,
+ GatherNd = 67,
}
// Base layer table to be used as part of other layers
@@ -382,6 +383,10 @@ table GatherDescriptor {
axis:int = 0;
}
+table GatherNdLayer {
+ base:LayerBase;
+}
+
/// @deprecated Use ComparisonLayer instead
table GreaterLayer {
base:LayerBase;
@@ -1072,6 +1077,7 @@ union Layer {
ChannelShuffleLayer,
Convolution3dLayer,
Pooling3dLayer,
+ GatherNdLayer,
}
table AnyLayer {
diff --git a/src/armnnSerializer/ArmnnSchema_generated.h b/src/armnnSerializer/ArmnnSchema_generated.h
index 76a6460c85..8f803f5af2 100644
--- a/src/armnnSerializer/ArmnnSchema_generated.h
+++ b/src/armnnSerializer/ArmnnSchema_generated.h
@@ -1,5 +1,5 @@
//
-// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
// automatically generated by the FlatBuffers compiler, do not modify
@@ -131,6 +131,9 @@ struct GatherLayerBuilder;
struct GatherDescriptor;
struct GatherDescriptorBuilder;
+struct GatherNdLayer;
+struct GatherNdLayerBuilder;
+
struct GreaterLayer;
struct GreaterLayerBuilder;
@@ -777,11 +780,12 @@ enum LayerType {
LayerType_ChannelShuffle = 64,
LayerType_Convolution3d = 65,
LayerType_Pooling3d = 66,
+ LayerType_GatherNd = 67,
LayerType_MIN = LayerType_Addition,
- LayerType_MAX = LayerType_Pooling3d
+ LayerType_MAX = LayerType_GatherNd
};
-inline const LayerType (&EnumValuesLayerType())[67] {
+inline const LayerType (&EnumValuesLayerType())[68] {
static const LayerType values[] = {
LayerType_Addition,
LayerType_Input,
@@ -849,13 +853,14 @@ inline const LayerType (&EnumValuesLayerType())[67] {
LayerType_UnidirectionalSequenceLstm,
LayerType_ChannelShuffle,
LayerType_Convolution3d,
- LayerType_Pooling3d
+ LayerType_Pooling3d,
+ LayerType_GatherNd
};
return values;
}
inline const char * const *EnumNamesLayerType() {
- static const char * const names[68] = {
+ static const char * const names[69] = {
"Addition",
"Input",
"Multiplication",
@@ -923,13 +928,14 @@ inline const char * const *EnumNamesLayerType() {
"ChannelShuffle",
"Convolution3d",
"Pooling3d",
+ "GatherNd",
nullptr
};
return names;
}
inline const char *EnumNameLayerType(LayerType e) {
- if (flatbuffers::IsOutRange(e, LayerType_Addition, LayerType_Pooling3d)) return "";
+ if (flatbuffers::IsOutRange(e, LayerType_Addition, LayerType_GatherNd)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesLayerType()[index];
}
@@ -1309,11 +1315,12 @@ enum Layer {
Layer_ChannelShuffleLayer = 65,
Layer_Convolution3dLayer = 66,
Layer_Pooling3dLayer = 67,
+ Layer_GatherNdLayer = 68,
Layer_MIN = Layer_NONE,
- Layer_MAX = Layer_Pooling3dLayer
+ Layer_MAX = Layer_GatherNdLayer
};
-inline const Layer (&EnumValuesLayer())[68] {
+inline const Layer (&EnumValuesLayer())[69] {
static const Layer values[] = {
Layer_NONE,
Layer_ActivationLayer,
@@ -1382,13 +1389,14 @@ inline const Layer (&EnumValuesLayer())[68] {
Layer_UnidirectionalSequenceLstmLayer,
Layer_ChannelShuffleLayer,
Layer_Convolution3dLayer,
- Layer_Pooling3dLayer
+ Layer_Pooling3dLayer,
+ Layer_GatherNdLayer
};
return values;
}
inline const char * const *EnumNamesLayer() {
- static const char * const names[69] = {
+ static const char * const names[70] = {
"NONE",
"ActivationLayer",
"AdditionLayer",
@@ -1457,13 +1465,14 @@ inline const char * const *EnumNamesLayer() {
"ChannelShuffleLayer",
"Convolution3dLayer",
"Pooling3dLayer",
+ "GatherNdLayer",
nullptr
};
return names;
}
inline const char *EnumNameLayer(Layer e) {
- if (flatbuffers::IsOutRange(e, Layer_NONE, Layer_Pooling3dLayer)) return "";
+ if (flatbuffers::IsOutRange(e, Layer_NONE, Layer_GatherNdLayer)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesLayer()[index];
}
@@ -1740,6 +1749,10 @@ template<> struct LayerTraits<armnnSerializer::Pooling3dLayer> {
static const Layer enum_value = Layer_Pooling3dLayer;
};
+template<> struct LayerTraits<armnnSerializer::GatherNdLayer> {
+ static const Layer enum_value = Layer_GatherNdLayer;
+};
+
bool VerifyLayer(flatbuffers::Verifier &verifier, const void *obj, Layer type);
bool VerifyLayerVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
@@ -4186,6 +4199,49 @@ inline flatbuffers::Offset<GatherDescriptor> CreateGatherDescriptor(
return builder_.Finish();
}
+struct GatherNdLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GatherNdLayerBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_BASE = 4
+ };
+ const armnnSerializer::LayerBase *base() const {
+ return GetPointer<const armnnSerializer::LayerBase *>(VT_BASE);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_BASE) &&
+ verifier.VerifyTable(base()) &&
+ verifier.EndTable();
+ }
+};
+
+struct GatherNdLayerBuilder {
+ typedef GatherNdLayer Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_base(flatbuffers::Offset<armnnSerializer::LayerBase> base) {
+ fbb_.AddOffset(GatherNdLayer::VT_BASE, base);
+ }
+ explicit GatherNdLayerBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GatherNdLayerBuilder &operator=(const GatherNdLayerBuilder &);
+ flatbuffers::Offset<GatherNdLayer> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<GatherNdLayer>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<GatherNdLayer> CreateGatherNdLayer(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<armnnSerializer::LayerBase> base = 0) {
+ GatherNdLayerBuilder builder_(_fbb);
+ builder_.add_base(base);
+ return builder_.Finish();
+}
+
/// @deprecated Use ComparisonLayer instead
struct GreaterLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef GreaterLayerBuilder Builder;
@@ -10534,6 +10590,9 @@ struct AnyLayer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const armnnSerializer::Pooling3dLayer *layer_as_Pooling3dLayer() const {
return layer_type() == armnnSerializer::Layer_Pooling3dLayer ? static_cast<const armnnSerializer::Pooling3dLayer *>(layer()) : nullptr;
}
+ const armnnSerializer::GatherNdLayer *layer_as_GatherNdLayer() const {
+ return layer_type() == armnnSerializer::Layer_GatherNdLayer ? static_cast<const armnnSerializer::GatherNdLayer *>(layer()) : nullptr;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_LAYER_TYPE) &&
@@ -10811,6 +10870,10 @@ template<> inline const armnnSerializer::Pooling3dLayer *AnyLayer::layer_as<armn
return layer_as_Pooling3dLayer();
}
+template<> inline const armnnSerializer::GatherNdLayer *AnyLayer::layer_as<armnnSerializer::GatherNdLayer>() const {
+ return layer_as_GatherNdLayer();
+}
+
struct AnyLayerBuilder {
typedef AnyLayer Table;
flatbuffers::FlatBufferBuilder &fbb_;
@@ -11309,6 +11372,10 @@ inline bool VerifyLayer(flatbuffers::Verifier &verifier, const void *obj, Layer
auto ptr = reinterpret_cast<const armnnSerializer::Pooling3dLayer *>(obj);
return verifier.VerifyTable(ptr);
}
+ case Layer_GatherNdLayer: {
+ auto ptr = reinterpret_cast<const armnnSerializer::GatherNdLayer *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return true;
}
}
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 971621d60c..3b9dfb0ae8 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -587,6 +587,17 @@ void SerializerStrategy::SerializeGatherLayer(const armnn::IConnectableLayer* la
CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
}
+void SerializerStrategy::SerializeGatherNdLayer(const armnn::IConnectableLayer* layer,
+ const char* name)
+{
+ IgnoreUnused(name);
+
+ auto fbGatherNdBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_GatherNd);
+ auto flatBufferLayer = serializer::CreateGatherNdLayer(m_flatBufferBuilder, fbGatherNdBaseLayer);
+
+ CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherNdLayer);
+}
+
void SerializerStrategy::SerializeInstanceNormalizationLayer(
const armnn::IConnectableLayer* layer,
const armnn::InstanceNormalizationDescriptor& instanceNormalizationDescriptor,
@@ -2134,6 +2145,11 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
SerializeGatherLayer(layer, layerDescriptor, name);
break;
}
+ case armnn::LayerType::GatherNd :
+ {
+ SerializeGatherNdLayer(layer, name);
+ break;
+ }
case armnn::LayerType::Input:
{
SerializeInputLayer(layer, id, name);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 3905e49cd1..98c1984cd2 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -191,6 +191,9 @@ private:
const armnn::GatherDescriptor& gatherDescriptor,
const char* name = nullptr);
+ void SerializeGatherNdLayer(const armnn::IConnectableLayer* layer,
+ const char* name = nullptr);
+
void SerializeInputLayer(const armnn::IConnectableLayer* layer,
armnn::LayerBindingId id,
const char* name = nullptr);
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 966dc6c669..a765290de8 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1109,6 +1109,70 @@ TEST_CASE("SerializeGather")
deserializedNetwork->ExecuteStrategy(verifier);
}
+TEST_CASE("SerializeGatherNd")
+{
+ class GatherNdLayerVerifier : public LayerVerifierBase
+ {
+ public:
+ GatherNdLayerVerifier(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos)
+ : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
+ {
+ armnn::IgnoreUnused(constants, id);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::Input:
+ case armnn::LayerType::Output:
+ case armnn::LayerType::Constant:
+ break;
+ default:
+ {
+ VerifyNameAndConnections(layer, name);
+ }
+ }
+ }
+ };
+
+ const std::string layerName("gatherNd");
+ armnn::TensorInfo paramsInfo({ 6, 3 }, armnn::DataType::QAsymmU8);
+ armnn::TensorInfo outputInfo({ 3, 3 }, armnn::DataType::QAsymmU8);
+ const armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Signed32, 0.0f, 0, true);
+
+ paramsInfo.SetQuantizationScale(1.0f);
+ paramsInfo.SetQuantizationOffset(0);
+ outputInfo.SetQuantizationScale(1.0f);
+ outputInfo.SetQuantizationOffset(0);
+
+ const std::vector<int32_t>& indicesData = {5, 1, 0};
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer *const constantLayer =
+ network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
+ armnn::IConnectableLayer *const gatherNdLayer = network->AddGatherNdLayer(layerName.c_str());
+ armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(0));
+ constantLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(1));
+ gatherNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(paramsInfo);
+ constantLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo);
+ gatherNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ GatherNdLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo});
+ deserializedNetwork->ExecuteStrategy(verifier);
+}
TEST_CASE("SerializeComparisonGreater")
{