diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 7a7c5a4375..38202fcf94 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1853,6 +1853,8 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef, std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index); + GatherDescriptor descriptor; + descriptor.m_Axis = ReadMandatoryNodeInt32Attribute(nodeDef, "axis"); // Infer shape of output tensor unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions(); @@ -1874,7 +1876,7 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef, const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType()); - IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str()); + IConnectableLayer* const layer = m_Network->AddGatherLayer(descriptor, nodeDef.name().c_str()); layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo); params.Connect(layer->GetInputSlot(0)); |