aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisMurtagh <francis.murtagh@arm.com>2019-01-24 10:53:39 +0000
committerFrancis Murtagh <francis.murtagh@arm.com>2019-01-24 12:43:25 +0000
commit94412aff782472be54dce4328e2ecee0225b3e97 (patch)
tree42de97582e1948ecd91bdb24de00a966d180f6e7
parent5f4e41ff596170dad9c073b007b3f53783a9e1f3 (diff)
downloadarmnn-94412aff782472be54dce4328e2ecee0225b3e97.tar.gz
IVGCVSW-2512 Add Gather operator parser to TfParser
* Add ParseGather to TFParser * Add Unit tests for Gather Operator !armnn:562 Change-Id: Idff45c2d3d8d683aa9eb2c4a63123c8d6054609e
-rw-r--r--CMakeLists.txt1
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp37
-rw-r--r--src/armnnTfParser/TfParser.hpp1
-rw-r--r--src/armnnTfParser/test/Gather.cpp167
4 files changed, 206 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3c1932d4dc..339efd0731 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -412,6 +412,7 @@ if(BUILD_UNIT_TESTS)
src/armnnTfParser/test/Equal.cpp
src/armnnTfParser/test/ExpandDims.cpp
src/armnnTfParser/test/FusedBatchNorm.cpp
+ src/armnnTfParser/test/Gather.cpp
src/armnnTfParser/test/Greater.cpp
src/armnnTfParser/test/Identity.cpp
src/armnnTfParser/test/LocalResponseNormalization.cpp
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 15a91d5275..43b0d86000 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -332,6 +332,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
{ "ExpandDims", &TfParser::ParseExpandDims },
{ "FusedBatchNorm", &TfParser::ParseFusedBatchNorm },
+ { "Gather", &TfParser::ParseGather},
{ "Greater", &TfParser::ParseGreater},
{ "ConcatV2", &TfParser::ParseConcat },
{ "LRN", &TfParser::ParseLrn },
@@ -1766,6 +1767,42 @@ ParsedTfOperationPtr TfParser::ProcessElementwiseLayer(
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+
+ // Infer shape of output tensor
+ unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions();
+ unsigned int indicesDim = indices.GetTensorInfo().GetNumDimensions();
+ unsigned int outputDim = paramsDim - 1 + indicesDim;
+
+ std::vector<unsigned int> dimSizes;
+
+ for (unsigned int i = 0; i < indicesDim; ++i)
+ {
+ dimSizes.push_back(indices.GetTensorInfo().GetShape()[i]);
+ }
+ for (unsigned int i = 1; i < paramsDim; ++i)
+ {
+ dimSizes.push_back(params.GetTensorInfo().GetShape()[i]);
+ }
+
+ const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data());
+
+ const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType());
+
+ IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo);
+
+ params.Connect(layer->GetInputSlot(0));
+ indices.Connect(layer->GetInputSlot(1));
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
ParsedTfOperationPtr TfParser::ParseGreater(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index f1b7205ff1..9a7d7827c3 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -163,6 +163,7 @@ private:
ParsedTfOperationPtr ParseEqual(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseGather(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseGreater(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp
new file mode 100644
index 0000000000..f40dc57556
--- /dev/null
+++ b/src/armnnTfParser/test/Gather.cpp
@@ -0,0 +1,167 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "armnnTfParser/ITfParser.hpp"
+
+#include "ParserPrototxtFixture.hpp"
+#include <PrototxtConversions.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+// helper for setting the dimensions in prototxt
+void dimsHelper(const std::vector<int>& dims, std::string& text){
+ for(u_int i=0; i<dims.size(); ++i){
+ text.append(R"(dim {
+ size: )");
+ text.append(std::to_string(dims[i]));
+ text.append(R"(
+ })");
+ }
+}
+
+// helper for converting from integer to octal representation
+void octalHelper(const std::vector<int>& indicesContent, std::string& text){
+ for (unsigned int i = 0; i < indicesContent.size(); ++i)
+ {
+ text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
+ }
+}
+
+struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ GatherFixture(const armnn::TensorShape& inputShape0,
+ const armnn::TensorShape& inputShape1,
+ const std::vector<int>& input1Content,
+ const std::vector<int>& input0Dims,
+ const std::vector<int>& input1Dims)
+ {
+ m_Prototext = R"(
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+)";
+ dimsHelper(input0Dims, m_Prototext);
+ m_Prototext.append(R"(
+ }
+ }
+ }
+}
+node {
+ name: "input1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+)");
+ dimsHelper(input1Dims, m_Prototext);
+ m_Prototext.append(R"(
+ }
+ tensor_content: ")");
+ octalHelper(input1Content, m_Prototext);
+ m_Prototext.append(R"("
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Gather"
+ input: "input0"
+ input: "input1"
+ attr {
+ key: "Tindices"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tparams"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )");
+ Setup({ { "input0", inputShape0 },
+ { "input1", inputShape1 } },
+ { "output" });
+
+ }
+};
+
+
+struct GatherFixture1DParams1DIndices : public GatherFixture
+{
+ GatherFixture1DParams1DIndices() : GatherFixture(
+ { 4, 1, 1, 1 },
+ { 4, 0, 0, 0 },
+ { 0, 2, 1, 3 },
+ { 4 },
+ { 4 }) {}
+};
+
+struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
+{
+ GatherFixture1DParamsMultiDimIndices() : GatherFixture(
+ { 4, 1, 1 },
+ { 2, 2, 1, 1 },
+ { 0, 1, 1, 3 },
+ { 4 },
+ { 2, 2 }) {}
+};
+
+struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
+{
+ GatherFixtureMultiDimParamMultiDimIndices() : GatherFixture(
+ { 5, 2, 1 },
+ { 2, 1, 4 },
+ { 1, 3, 0, 2 },
+ { 5, 2 },
+ { 2, 2 }) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
+{
+ RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
+
+ { { "output", { 1, 3, 2, 4 } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseGather1DParamsMultiDimIndices, GatherFixture1DParamsMultiDimIndices)
+{
+ RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
+
+ { { "output", { 1, 2, 2, 4 } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseGatherMultiDimParamMultiDimIndices, GatherFixtureMultiDimParamMultiDimIndices)
+{
+ RunTest<4>({ { "input0", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 } } },
+
+ { { "output", { 3, 4, 7, 8, 1, 2, 5, 6} } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()