From 526647333571169076f5e72c9fb18c71025bf7c0 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Mon, 29 Jun 2020 16:27:03 +0100 Subject: IVGCVSW-4903 Connect axis parameter in Gather from android to ACL. !android-nn-driver:3302 Signed-off-by: Teresa Charlin Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d --- src/armnnTfParser/TfParser.cpp | 6 ++++-- src/armnnTfParser/test/Gather.cpp | 26 +++++++++++++++++++++----- 2 files changed, 25 insertions(+), 7 deletions(-) (limited to 'src/armnnTfParser') diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 7a7c5a4375..38202fcf94 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -1853,6 +1853,8 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef, std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index); + GatherDescriptor descriptor; + descriptor.m_Axis = ReadMandatoryNodeInt32Attribute(nodeDef, "axis"); // Infer shape of output tensor unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions(); @@ -1874,7 +1876,7 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef, const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType()); - IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str()); + IConnectableLayer* const layer = m_Network->AddGatherLayer(descriptor, nodeDef.name().c_str()); layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo); params.Connect(layer->GetInputSlot(0)); diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp index 8c4b891141..ab5fb7104d 100644 --- a/src/armnnTfParser/test/Gather.cpp +++ b/src/armnnTfParser/test/Gather.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -38,7 +38,8 @@ struct GatherFixture : public armnnUtils::ParserPrototxtFixture& input1Content, const std::vector& input0Dims, - const std::vector& input1Dims) + const std::vector& input1Dims, + int axis = 0) { m_Prototext = R"( node { @@ -56,6 +57,7 @@ node { shape { )"; dimsHelper(input0Dims, m_Prototext); + m_Prototext.append(R"( } } @@ -78,6 +80,7 @@ node { tensor_shape { )"); dimsHelper(input1Dims, m_Prototext); + m_Prototext.append(R"( } tensor_content: ")"); @@ -104,8 +107,18 @@ node { type: DT_FLOAT } } + attr { + key: "axis" + value { + i: )"); + m_Prototext += std::to_string(axis); + + m_Prototext.append(R"( + } + } } )"); + Setup({ { "input0", inputShape0 }, { "input1", inputShape1 } }, { "output" }); @@ -121,7 +134,8 @@ struct GatherFixture1DParams1DIndices : public GatherFixture { 4, 0, 0, 0 }, { 0, 2, 1, 3 }, { 4 }, - { 4 }) {} + { 4 }, + 0) {} }; struct GatherFixture1DParamsMultiDimIndices : public GatherFixture @@ -131,7 +145,8 @@ struct GatherFixture1DParamsMultiDimIndices : public GatherFixture { 2, 2, 1, 1 }, { 0, 1, 1, 3 }, { 4 }, - { 2, 2 }) {} + { 2, 2 }, + 0) {} }; struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture @@ -141,7 +156,8 @@ struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture { 2, 1, 4 }, { 1, 3, 0, 2 }, { 5, 2 }, - { 2, 2 }) {} + { 2, 2 }, + 0) {} }; BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices) -- cgit v1.2.1