aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/test/Reshape.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/test/Reshape.cpp')
-rw-r--r--src/armnnOnnxParser/test/Reshape.cpp254
1 files changed, 254 insertions, 0 deletions
diff --git a/src/armnnOnnxParser/test/Reshape.cpp b/src/armnnOnnxParser/test/Reshape.cpp
index e9bcd278cf..97198761e5 100644
--- a/src/armnnOnnxParser/test/Reshape.cpp
+++ b/src/armnnOnnxParser/test/Reshape.cpp
@@ -5,6 +5,7 @@
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
+#include "OnnxParserTestUtils.hpp"
TEST_SUITE("OnnxParser_Reshape")
{
@@ -211,4 +212,257 @@ TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
CHECK_THROWS_AS(Setup(), armnn::ParseException);
}
+struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape,
+ const std::vector<int>& shapeInputShape,
+ const std::vector<int>& outputShape,
+ const std::string& shape)
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "onnx-example"
+ graph {
+ name: "ReshapeGrapn"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
+ }
+ }
+ }
+ }
+ input {
+ name: "Shape"
+ type {
+ tensor_type {
+ elem_type: 7
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
+ }
+ }
+ }
+ }
+ node {
+ input: "Input"
+ input: "Shape"
+ output: "Output"
+ name: "reshape"
+ op_type: "Reshape"
+ }
+ initializer {
+ dims: 2
+ data_type: 7
+ )" + shape + R"(
+ name: "Shape"
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ }
+};
+
+struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1")
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
+ { 2 },
+ { 2, 6 },
+ "int64_data: -1 int64_data: 6")
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
+ { 3 },
+ { 3, 1, 4 },
+ "int64_data: 3 int64_data: -1 int64_data: 4")
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture(
+ { 2, 3, 1, 2 },
+ { 4 },
+ { 3, 1, 2, 2 },
+ "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1")
+ {
+ Setup();
+ }
+};
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest")
+{
+ RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest")
+{
+ RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest")
+{
+ RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest")
+{
+ RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
+}
+
+struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ ReshapeNonConstShapeFixture(const std::vector<int>& inputShape,
+ const std::vector<int>& shapeInputShape,
+ const std::vector<int>& outputShape)
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "onnx-example"
+ graph {
+ name: "ReshapeGrapn"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
+ }
+ }
+ }
+ }
+ input {
+ name: "Shape"
+ type {
+ tensor_type {
+ elem_type: 7
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
+ }
+ }
+ }
+ }
+ node {
+ input: "Input"
+ input: "Shape"
+ output: "Output"
+ name: "reshape"
+ op_type: "Reshape"
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ }
+};
+
+struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 })
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 })
+ {
+ Setup();
+ }
+};
+
+struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 })
+ {
+ }
+};
+
+struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 })
+ {
+ }
+};
+
+TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest")
+{
+ RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest")
+{
+ RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
+ 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
+ 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest")
+{
+ CHECK_THROWS_AS(Setup(), armnn::ParseException);
+}
+
+TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest")
+{
+ CHECK_THROWS_AS(Setup(), armnn::ParseException);
+}
+
}