diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 15a91d5275..43b0d86000 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -332,6 +332,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope { "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D }, { "ExpandDims", &TfParser::ParseExpandDims }, { "FusedBatchNorm", &TfParser::ParseFusedBatchNorm }, + { "Gather", &TfParser::ParseGather}, { "Greater", &TfParser::ParseGreater}, { "ConcatV2", &TfParser::ParseConcat }, { "LRN", &TfParser::ParseLrn }, @@ -1766,6 +1767,42 @@ ParsedTfOperationPtr TfParser::ProcessElementwiseLayer( return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } +ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index); + + // Infer shape of output tensor + unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions(); + unsigned int indicesDim = indices.GetTensorInfo().GetNumDimensions(); + unsigned int outputDim = paramsDim - 1 + indicesDim; + + std::vector<unsigned int> dimSizes; + + for (unsigned int i = 0; i < indicesDim; ++i) + { + dimSizes.push_back(indices.GetTensorInfo().GetShape()[i]); + } + for (unsigned int i = 1; i < paramsDim; ++i) + { + dimSizes.push_back(params.GetTensorInfo().GetShape()[i]); + } + + const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data()); + + const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType()); + + IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str()); + layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo); + + params.Connect(layer->GetInputSlot(0)); + indices.Connect(layer->GetInputSlot(1)); + + return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); +} + ParsedTfOperationPtr TfParser::ParseGreater(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { |