aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/test
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-09-23 16:12:19 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-10-07 14:43:09 +0000
commit452274c86245082ce20563ede12b92af81dba38a (patch)
tree79718c6cf86acbb21138068c17aae15c4b172306 /src/armnnOnnxParser/test
parent4d217c02fe2c0a32ff9da69d8fe375a75173c0f3 (diff)
downloadarmnn-452274c86245082ce20563ede12b92af81dba38a.tar.gz
IVGCVSW-6459 Add support of scalar and flexible output datatypes to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018
Diffstat (limited to 'src/armnnOnnxParser/test')
-rw-r--r--src/armnnOnnxParser/test/Gather.cpp22
-rw-r--r--src/armnnOnnxParser/test/Shape.cpp56
-rw-r--r--src/armnnOnnxParser/test/Unsqueeze.cpp14
3 files changed, 54 insertions, 38 deletions
diff --git a/src/armnnOnnxParser/test/Gather.cpp b/src/armnnOnnxParser/test/Gather.cpp
index 1d214419c4..8fd9021ebc 100644
--- a/src/armnnOnnxParser/test/Gather.cpp
+++ b/src/armnnOnnxParser/test/Gather.cpp
@@ -85,6 +85,14 @@ struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPar
}
};
+struct GatherScalarFixture : GatherMainFixture
+{
+ GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { })
+ {
+ Setup();
+ }
+};
+
struct Gather1dFixture : GatherMainFixture
{
Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 })
@@ -117,16 +125,22 @@ struct Gather4dFixture : GatherMainFixture
}
};
+TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest")
+{
+ RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
+ {{"output", { 1.0f }}});
+}
+
TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest")
{
- RunTest<1, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
- {{"output", {1.0f, 3.0f, 2.0f, 6.0f}}});
+ RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
+ {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}});
}
TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest")
{
- RunTest<2, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
- {{"output", {3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
+ RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
+ {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
}
TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest")
diff --git a/src/armnnOnnxParser/test/Shape.cpp b/src/armnnOnnxParser/test/Shape.cpp
index b033b2d8bf..e01c4b8b55 100644
--- a/src/armnnOnnxParser/test/Shape.cpp
+++ b/src/armnnOnnxParser/test/Shape.cpp
@@ -5,9 +5,11 @@
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
+#include "OnnxParserTestUtils.hpp"
TEST_SUITE("OnnxParser_Shape")
{
+
struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
ShapeMainFixture(const std::string& inputType,
@@ -31,7 +33,7 @@ struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPars
tensor_type {
elem_type: )" + inputType + R"(
shape {
- )" + ConstructShapeString(inputShape) + R"(
+ )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
}
}
}
@@ -54,91 +56,77 @@ struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPars
version: 10
})";
}
- std::string ConstructShapeString(const std::vector<int>& shape)
- {
- std::string shapeStr;
- for (int i : shape)
- {
- shapeStr = fmt::format("{} dim {{ dim_value: {} }}", shapeStr, i);
- }
- return shapeStr;
- }
};
-struct ShapeValidFloatFixture : ShapeMainFixture
+struct ShapeFloatFixture : ShapeMainFixture
{
- ShapeValidFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) {
+ ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 })
+ {
Setup();
}
};
-struct ShapeValidIntFixture : ShapeMainFixture
+struct ShapeIntFixture : ShapeMainFixture
{
- ShapeValidIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) {
+ ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 })
+ {
Setup();
}
};
struct Shape3DFixture : ShapeMainFixture
{
- Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) {
+ Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 })
+ {
Setup();
}
};
struct Shape2DFixture : ShapeMainFixture
{
- Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) {
+ Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 })
+ {
Setup();
}
};
struct Shape1DFixture : ShapeMainFixture
{
- Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) {
+ Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 })
+ {
Setup();
}
};
-struct ShapeInvalidFixture : ShapeMainFixture
-{
- ShapeInvalidFixture() : ShapeMainFixture("1", "1", "4", { 1, 3, 1, 5 }) {}
-};
-
-TEST_CASE_FIXTURE(ShapeValidFloatFixture, "FloatValidShapeTest")
+TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}});
}
-TEST_CASE_FIXTURE(ShapeValidIntFixture, "IntValidShapeTest")
+TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest")
{
- RunTest<2, int>({{"Input", { 0, 1, 2, 3, 4,
+ RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4,
4, 3, 2, 1, 0,
0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}});
}
TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}});
}
TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
}
TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest")
{
- RunTest<2, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
-}
-
-TEST_CASE_FIXTURE(ShapeInvalidFixture, "IncorrectOutputDataShapeTest")
-{
- CHECK_THROWS_AS(Setup(), armnn::ParseException);
+ RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
}
}
diff --git a/src/armnnOnnxParser/test/Unsqueeze.cpp b/src/armnnOnnxParser/test/Unsqueeze.cpp
index 95a191e46b..7ba87bc680 100644
--- a/src/armnnOnnxParser/test/Unsqueeze.cpp
+++ b/src/armnnOnnxParser/test/Unsqueeze.cpp
@@ -77,6 +77,14 @@ struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture
}
};
+struct UnsqueezeScalarFixture : UnsqueezeFixture
+{
+ UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 })
+ {
+ Setup();
+ }
+};
+
TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest")
{
RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
@@ -107,6 +115,12 @@ TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest")
6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
}
+TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest")
+{
+ RunTest<1, float>({{"Input", { 1.0f }}},
+ {{"Output", { 1.0f }}});
+}
+
struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
UnsqueezeInputAxesFixture()