diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-09-23 16:12:19 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-10-07 14:43:09 +0000 |
commit | 452274c86245082ce20563ede12b92af81dba38a (patch) | |
tree | 79718c6cf86acbb21138068c17aae15c4b172306 /src/armnnOnnxParser/test/Gather.cpp | |
parent | 4d217c02fe2c0a32ff9da69d8fe375a75173c0f3 (diff) | |
download | armnn-452274c86245082ce20563ede12b92af81dba38a.tar.gz |
IVGCVSW-6459 Add support of scalar and flexible output datatypes to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018
Diffstat (limited to 'src/armnnOnnxParser/test/Gather.cpp')
-rw-r--r-- | src/armnnOnnxParser/test/Gather.cpp | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/src/armnnOnnxParser/test/Gather.cpp b/src/armnnOnnxParser/test/Gather.cpp index 1d214419c4..8fd9021ebc 100644 --- a/src/armnnOnnxParser/test/Gather.cpp +++ b/src/armnnOnnxParser/test/Gather.cpp @@ -85,6 +85,14 @@ struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPar } }; +struct GatherScalarFixture : GatherMainFixture +{ + GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { }) + { + Setup(); + } +}; + struct Gather1dFixture : GatherMainFixture { Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 }) @@ -117,16 +125,22 @@ struct Gather4dFixture : GatherMainFixture } }; +TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest") +{ + RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, + {{"output", { 1.0f }}}); +} + TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest") { - RunTest<1, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, - {{"output", {1.0f, 3.0f, 2.0f, 6.0f}}}); + RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, + {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}}); } TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest") { - RunTest<2, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, - {{"output", {3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); + RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, + {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); } TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest") |