aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorFrancisMurtagh <francis.murtagh@arm.com>2019-01-24 10:53:39 +0000
committerFrancis Murtagh <francis.murtagh@arm.com>2019-01-24 12:43:25 +0000
commit94412aff782472be54dce4328e2ecee0225b3e97 (patch)
tree42de97582e1948ecd91bdb24de00a966d180f6e7 /src/armnnTfParser/TfParser.cpp
parent5f4e41ff596170dad9c073b007b3f53783a9e1f3 (diff)
downloadarmnn-94412aff782472be54dce4328e2ecee0225b3e97.tar.gz
IVGCVSW-2512 Add Gather operator parser to TfParser
* Add ParseGather to TFParser * Add Unit tests for Gather Operator !armnn:562 Change-Id: Idff45c2d3d8d683aa9eb2c4a63123c8d6054609e
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp37
1 files changed, 37 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 15a91d5275..43b0d86000 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -332,6 +332,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
{ "ExpandDims", &TfParser::ParseExpandDims },
{ "FusedBatchNorm", &TfParser::ParseFusedBatchNorm },
+ { "Gather", &TfParser::ParseGather},
{ "Greater", &TfParser::ParseGreater},
{ "ConcatV2", &TfParser::ParseConcat },
{ "LRN", &TfParser::ParseLrn },
@@ -1766,6 +1767,42 @@ ParsedTfOperationPtr TfParser::ProcessElementwiseLayer(
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+
+ // Infer shape of output tensor
+ unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions();
+ unsigned int indicesDim = indices.GetTensorInfo().GetNumDimensions();
+ unsigned int outputDim = paramsDim - 1 + indicesDim;
+
+ std::vector<unsigned int> dimSizes;
+
+ for (unsigned int i = 0; i < indicesDim; ++i)
+ {
+ dimSizes.push_back(indices.GetTensorInfo().GetShape()[i]);
+ }
+ for (unsigned int i = 1; i < paramsDim; ++i)
+ {
+ dimSizes.push_back(params.GetTensorInfo().GetShape()[i]);
+ }
+
+ const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data());
+
+ const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType());
+
+ IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo);
+
+ params.Connect(layer->GetInputSlot(0));
+ indices.Connect(layer->GetInputSlot(1));
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
ParsedTfOperationPtr TfParser::ParseGreater(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{