diff options
Diffstat (limited to 'src/armnnOnnxParser/test/Gather.cpp')
-rw-r--r-- | src/armnnOnnxParser/test/Gather.cpp | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/src/armnnOnnxParser/test/Gather.cpp b/src/armnnOnnxParser/test/Gather.cpp index 1d214419c4..8fd9021ebc 100644 --- a/src/armnnOnnxParser/test/Gather.cpp +++ b/src/armnnOnnxParser/test/Gather.cpp @@ -85,6 +85,14 @@ struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPar } }; +struct GatherScalarFixture : GatherMainFixture +{ + GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { }) + { + Setup(); + } +}; + struct Gather1dFixture : GatherMainFixture { Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 }) @@ -117,16 +125,22 @@ struct Gather4dFixture : GatherMainFixture } }; +TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest") +{ + RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, + {{"output", { 1.0f }}}); +} + TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest") { - RunTest<1, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, - {{"output", {1.0f, 3.0f, 2.0f, 6.0f}}}); + RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, + {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}}); } TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest") { - RunTest<2, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, - {{"output", {3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); + RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, + {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); } TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest") |