aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/test/Shape.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/test/Shape.cpp')
-rw-r--r--src/armnnOnnxParser/test/Shape.cpp56
1 files changed, 22 insertions, 34 deletions
diff --git a/src/armnnOnnxParser/test/Shape.cpp b/src/armnnOnnxParser/test/Shape.cpp
index b033b2d8bf..e01c4b8b55 100644
--- a/src/armnnOnnxParser/test/Shape.cpp
+++ b/src/armnnOnnxParser/test/Shape.cpp
@@ -5,9 +5,11 @@
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
+#include "OnnxParserTestUtils.hpp"
TEST_SUITE("OnnxParser_Shape")
{
+
struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
ShapeMainFixture(const std::string& inputType,
@@ -31,7 +33,7 @@ struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPars
tensor_type {
elem_type: )" + inputType + R"(
shape {
- )" + ConstructShapeString(inputShape) + R"(
+ )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
}
}
}
@@ -54,91 +56,77 @@ struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPars
version: 10
})";
}
- std::string ConstructShapeString(const std::vector<int>& shape)
- {
- std::string shapeStr;
- for (int i : shape)
- {
- shapeStr = fmt::format("{} dim {{ dim_value: {} }}", shapeStr, i);
- }
- return shapeStr;
- }
};
-struct ShapeValidFloatFixture : ShapeMainFixture
+struct ShapeFloatFixture : ShapeMainFixture
{
- ShapeValidFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) {
+ ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 })
+ {
Setup();
}
};
-struct ShapeValidIntFixture : ShapeMainFixture
+struct ShapeIntFixture : ShapeMainFixture
{
- ShapeValidIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) {
+ ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 })
+ {
Setup();
}
};
struct Shape3DFixture : ShapeMainFixture
{
- Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) {
+ Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 })
+ {
Setup();
}
};
struct Shape2DFixture : ShapeMainFixture
{
- Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) {
+ Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 })
+ {
Setup();
}
};
struct Shape1DFixture : ShapeMainFixture
{
- Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) {
+ Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 })
+ {
Setup();
}
};
-struct ShapeInvalidFixture : ShapeMainFixture
-{
- ShapeInvalidFixture() : ShapeMainFixture("1", "1", "4", { 1, 3, 1, 5 }) {}
-};
-
-TEST_CASE_FIXTURE(ShapeValidFloatFixture, "FloatValidShapeTest")
+TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}});
}
-TEST_CASE_FIXTURE(ShapeValidIntFixture, "IntValidShapeTest")
+TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest")
{
- RunTest<2, int>({{"Input", { 0, 1, 2, 3, 4,
+ RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4,
4, 3, 2, 1, 0,
0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}});
}
TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}});
}
TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
}
TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest")
{
- RunTest<2, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
-}
-
-TEST_CASE_FIXTURE(ShapeInvalidFixture, "IncorrectOutputDataShapeTest")
-{
- CHECK_THROWS_AS(Setup(), armnn::ParseException);
+ RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
}
}