diff options
Diffstat (limited to 'src/armnnOnnxParser/test/Reshape.cpp')
-rw-r--r-- | src/armnnOnnxParser/test/Reshape.cpp | 254 |
1 files changed, 254 insertions, 0 deletions
diff --git a/src/armnnOnnxParser/test/Reshape.cpp b/src/armnnOnnxParser/test/Reshape.cpp index e9bcd278cf..97198761e5 100644 --- a/src/armnnOnnxParser/test/Reshape.cpp +++ b/src/armnnOnnxParser/test/Reshape.cpp @@ -5,6 +5,7 @@ #include "armnnOnnxParser/IOnnxParser.hpp" #include "ParserPrototxtFixture.hpp" +#include "OnnxParserTestUtils.hpp" TEST_SUITE("OnnxParser_Reshape") { @@ -211,4 +212,257 @@ TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape") CHECK_THROWS_AS(Setup(), armnn::ParseException); } +struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> +{ + ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape, + const std::vector<int>& shapeInputShape, + const std::vector<int>& outputShape, + const std::string& shape) + { + m_Prototext = R"( + ir_version: 3 + producer_name: "onnx-example" + graph { + name: "ReshapeGrapn" + input { + name: "Input" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( + } + } + } + } + input { + name: "Shape" + type { + tensor_type { + elem_type: 7 + shape { + )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( + } + } + } + } + node { + input: "Input" + input: "Shape" + output: "Output" + name: "reshape" + op_type: "Reshape" + } + initializer { + dims: 2 + data_type: 7 + )" + shape + R"( + name: "Shape" + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + } + opset_import { + version: 7 + })"; + } +}; + +struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1") + { + Setup(); + } +}; + +struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, + { 2 }, + { 2, 6 }, + "int64_data: -1 int64_data: 6") + { + Setup(); + } +}; + +struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, + { 3 }, + { 3, 1, 4 }, + "int64_data: 3 int64_data: -1 int64_data: 4") + { + Setup(); + } +}; + +struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture( + { 2, 3, 1, 2 }, + { 4 }, + { 3, 1, 2, 2 }, + "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1") + { + Setup(); + } +}; + +TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest") +{ + RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest") +{ + RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest") +{ + RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest") +{ + RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); +} + +struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> +{ + ReshapeNonConstShapeFixture(const std::vector<int>& inputShape, + const std::vector<int>& shapeInputShape, + const std::vector<int>& outputShape) + { + m_Prototext = R"( + ir_version: 3 + producer_name: "onnx-example" + graph { + name: "ReshapeGrapn" + input { + name: "Input" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( + } + } + } + } + input { + name: "Shape" + type { + tensor_type { + elem_type: 7 + shape { + )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( + } + } + } + } + node { + input: "Input" + input: "Shape" + output: "Output" + name: "reshape" + op_type: "Reshape" + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + } + opset_import { + version: 7 + })"; + } +}; + +struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }) + { + Setup(); + } +}; + +struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 }) + { + Setup(); + } +}; + +struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 }) + { + } +}; + +struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 }) + { + } +}; + +TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest") +{ + RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest") +{ + RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest") +{ + CHECK_THROWS_AS(Setup(), armnn::ParseException); +} + +TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest") +{ + CHECK_THROWS_AS(Setup(), armnn::ParseException); +} + } |