aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-10-18 12:35:19 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2021-10-21 13:14:51 +0000
commit4b536e323abd09da9630502a8fb7d0be50e1ad45 (patch)
tree2b819a7673476fbbceac43ca0ed3ff6db492253d
parentf437213e4b54f0179129395828e549c02973e02f (diff)
downloadarmnn-4b536e323abd09da9630502a8fb7d0be50e1ad45.tar.gz
IVGCVSW-6451 Add support for Reshape when the target shape is dynamic
and batch size is unknown to ONNX parser Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I46b2daccce9e1a21d9d0550ac4126d2c79dbd37b
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp55
-rw-r--r--src/armnnOnnxParser/test/Reshape.cpp254
2 files changed, 292 insertions, 17 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index eb24bb5425..d97fa1c4f1 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -2096,12 +2096,42 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
m_TensorsInfo[node.input(1)].m_dtype,
onnx::TensorProto::INT64); //shape
- if(!m_TensorsInfo[node.input(1)].isConstant())
+ TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+
+ std::vector<unsigned int> targetShape;
+ if(m_TensorsInfo[node.input(1)].isConstant())
{
- throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}",
- node.input(1),
- node.name(),
- CHECK_LOCATION().AsString()));
+ unsigned int dims = static_cast<unsigned int>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
+ targetShape.reserve(dims);
+
+ for(uint i = 0; i < dims; i++)
+ {
+ int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
+ targetShape[i]= static_cast<unsigned int>(val);
+ }
+ }
+ else
+ {
+ // The parser only supports shape (batch, -1) or (-1) for non-constant shape input.
+ unsigned int dims = m_TensorsInfo[node.input(1)].m_info->GetNumDimensions();
+ TensorShape shapes = m_TensorsInfo[node.input(1)].m_info->GetShape();
+ if (dims != 1 || shapes[0] > 2)
+ {
+ throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}",
+ node.input(1),
+ node.name(),
+ CHECK_LOCATION().AsString()));
+ }
+
+ unsigned int numInputElements = m_TensorsInfo[node.input(0)].m_info->GetNumElements();
+ if (shapes[0] == 1)
+ {
+ targetShape = { numInputElements };
+ }
+ else if (shapes[0] == 2)
+ {
+ targetShape = { inputShape[0] , numInputElements / inputShape[0] };
+ }
}
if(m_TensorsInfo[node.input(0)].isConstant())
@@ -2116,20 +2146,11 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
}
else
{
- TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
-
if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
{
- uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
- TensorShape targetShape{static_cast<unsigned int>(dims), 1};
-
- for(uint i = 0; i < dims; i++)
- {
- int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
- targetShape[i]= static_cast<unsigned int>(val);
- }
-
- auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0));
+ auto outInfo = ComputeReshapeInfo(
+ TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
+ inputShape, node.output(0));
m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
}
diff --git a/src/armnnOnnxParser/test/Reshape.cpp b/src/armnnOnnxParser/test/Reshape.cpp
index e9bcd278cf..97198761e5 100644
--- a/src/armnnOnnxParser/test/Reshape.cpp
+++ b/src/armnnOnnxParser/test/Reshape.cpp
@@ -5,6 +5,7 @@
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
+#include "OnnxParserTestUtils.hpp"
TEST_SUITE("OnnxParser_Reshape")
{
@@ -211,4 +212,257 @@ TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
CHECK_THROWS_AS(Setup(), armnn::ParseException);
}
+struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape,
+ const std::vector<int>& shapeInputShape,
+ const std::vector<int>& outputShape,
+ const std::string& shape)
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "onnx-example"
+ graph {
+ name: "ReshapeGrapn"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
+ }
+ }
+ }
+ }
+ input {
+ name: "Shape"
+ type {
+ tensor_type {
+ elem_type: 7
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
+ }
+ }
+ }
+ }
+ node {
+ input: "Input"
+ input: "Shape"
+ output: "Output"
+ name: "reshape"
+ op_type: "Reshape"
+ }
+ initializer {
+ dims: 2
+ data_type: 7
+ )" + shape + R"(
+ name: "Shape"
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ }
+};
+
+struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1")
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
+ { 2 },
+ { 2, 6 },
+ "int64_data: -1 int64_data: 6")
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
+ { 3 },
+ { 3, 1, 4 },
+ "int64_data: 3 int64_data: -1 int64_data: 4")
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture
+{
+ ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture(
+ { 2, 3, 1, 2 },
+ { 4 },
+ { 3, 1, 2, 2 },
+ "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1")
+ {
+ Setup();
+ }
+};
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest")
+{
+ RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest")
+{
+ RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest")
+{
+ RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest")
+{
+ RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
+}
+
+struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ ReshapeNonConstShapeFixture(const std::vector<int>& inputShape,
+ const std::vector<int>& shapeInputShape,
+ const std::vector<int>& outputShape)
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "onnx-example"
+ graph {
+ name: "ReshapeGrapn"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
+ }
+ }
+ }
+ }
+ input {
+ name: "Shape"
+ type {
+ tensor_type {
+ elem_type: 7
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
+ }
+ }
+ }
+ }
+ node {
+ input: "Input"
+ input: "Shape"
+ output: "Output"
+ name: "reshape"
+ op_type: "Reshape"
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ }
+};
+
+struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 })
+ {
+ Setup();
+ }
+};
+
+struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 })
+ {
+ Setup();
+ }
+};
+
+struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 })
+ {
+ }
+};
+
+struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture
+{
+ ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 })
+ {
+ }
+};
+
+TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest")
+{
+ RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest")
+{
+ RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
+ 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
+ {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
+ 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}});
+}
+
+TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest")
+{
+ CHECK_THROWS_AS(Setup(), armnn::ParseException);
+}
+
+TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest")
+{
+ CHECK_THROWS_AS(Setup(), armnn::ParseException);
+}
+
}