aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 7a7c5a4375..38202fcf94 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1853,6 +1853,8 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef,
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+ GatherDescriptor descriptor;
+ descriptor.m_Axis = ReadMandatoryNodeInt32Attribute(nodeDef, "axis");
// Infer shape of output tensor
unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions();
@@ -1874,7 +1876,7 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef,
const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType());
- IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str());
+ IConnectableLayer* const layer = m_Network->AddGatherLayer(descriptor, nodeDef.name().c_str());
layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo);
params.Connect(layer->GetInputSlot(0));