diff options
Diffstat (limited to 'src/armnnTfParser/test/Gather.cpp')
-rw-r--r-- | src/armnnTfParser/test/Gather.cpp | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp index 8c4b891141..ab5fb7104d 100644 --- a/src/armnnTfParser/test/Gather.cpp +++ b/src/armnnTfParser/test/Gather.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -38,7 +38,8 @@ struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::I const armnn::TensorShape& inputShape1, const std::vector<int>& input1Content, const std::vector<int>& input0Dims, - const std::vector<int>& input1Dims) + const std::vector<int>& input1Dims, + int axis = 0) { m_Prototext = R"( node { @@ -56,6 +57,7 @@ node { shape { )"; dimsHelper(input0Dims, m_Prototext); + m_Prototext.append(R"( } } @@ -78,6 +80,7 @@ node { tensor_shape { )"); dimsHelper(input1Dims, m_Prototext); + m_Prototext.append(R"( } tensor_content: ")"); @@ -104,8 +107,18 @@ node { type: DT_FLOAT } } + attr { + key: "axis" + value { + i: )"); + m_Prototext += std::to_string(axis); + + m_Prototext.append(R"( + } + } } )"); + Setup({ { "input0", inputShape0 }, { "input1", inputShape1 } }, { "output" }); @@ -121,7 +134,8 @@ struct GatherFixture1DParams1DIndices : public GatherFixture { 4, 0, 0, 0 }, { 0, 2, 1, 3 }, { 4 }, - { 4 }) {} + { 4 }, + 0) {} }; struct GatherFixture1DParamsMultiDimIndices : public GatherFixture @@ -131,7 +145,8 @@ struct GatherFixture1DParamsMultiDimIndices : public GatherFixture { 2, 2, 1, 1 }, { 0, 1, 1, 3 }, { 4 }, - { 2, 2 }) {} + { 2, 2 }, + 0) {} }; struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture @@ -141,7 +156,8 @@ struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture { 2, 1, 4 }, { 1, 3, 0, 2 }, { 5, 2 }, - { 2, 2 }) {} + { 2, 2 }, + 0) {} }; BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices) |