aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/test/Gather.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/test/Gather.cpp')
-rw-r--r--src/armnnOnnxParser/test/Gather.cpp22
1 files changed, 18 insertions, 4 deletions
diff --git a/src/armnnOnnxParser/test/Gather.cpp b/src/armnnOnnxParser/test/Gather.cpp
index 1d214419c4..8fd9021ebc 100644
--- a/src/armnnOnnxParser/test/Gather.cpp
+++ b/src/armnnOnnxParser/test/Gather.cpp
@@ -85,6 +85,14 @@ struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPar
}
};
+struct GatherScalarFixture : GatherMainFixture
+{
+ GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { })
+ {
+ Setup();
+ }
+};
+
struct Gather1dFixture : GatherMainFixture
{
Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 })
@@ -117,16 +125,22 @@ struct Gather4dFixture : GatherMainFixture
}
};
+TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest")
+{
+ RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
+ {{"output", { 1.0f }}});
+}
+
TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest")
{
- RunTest<1, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
- {{"output", {1.0f, 3.0f, 2.0f, 6.0f}}});
+ RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
+ {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}});
}
TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest")
{
- RunTest<2, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
- {{"output", {3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
+ RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
+ {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
}
TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest")