From 4b536e323abd09da9630502a8fb7d0be50e1ad45 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Mon, 18 Oct 2021 12:35:19 +0100 Subject: IVGCVSW-6451 Add support for Reshape when the target shape is dynamic and batch size is unknown to ONNX parser Signed-off-by: Narumol Prangnawarat Change-Id: I46b2daccce9e1a21d9d0550ac4126d2c79dbd37b --- src/armnnOnnxParser/OnnxParser.cpp | 55 +++++--- src/armnnOnnxParser/test/Reshape.cpp | 254 +++++++++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+), 17 deletions(-) diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index eb24bb5425..d97fa1c4f1 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -2096,12 +2096,42 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) m_TensorsInfo[node.input(1)].m_dtype, onnx::TensorProto::INT64); //shape - if(!m_TensorsInfo[node.input(1)].isConstant()) + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + + std::vector targetShape; + if(m_TensorsInfo[node.input(1)].isConstant()) { - throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}", - node.input(1), - node.name(), - CHECK_LOCATION().AsString())); + unsigned int dims = static_cast(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size()); + targetShape.reserve(dims); + + for(uint i = 0; i < dims; i++) + { + int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast(i))); + targetShape[i]= static_cast(val); + } + } + else + { + // The parser only supports shape (batch, -1) or (-1) for non-constant shape input. + unsigned int dims = m_TensorsInfo[node.input(1)].m_info->GetNumDimensions(); + TensorShape shapes = m_TensorsInfo[node.input(1)].m_info->GetShape(); + if (dims != 1 || shapes[0] > 2) + { + throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}", + node.input(1), + node.name(), + CHECK_LOCATION().AsString())); + } + + unsigned int numInputElements = m_TensorsInfo[node.input(0)].m_info->GetNumElements(); + if (shapes[0] == 1) + { + targetShape = { numInputElements }; + } + else if (shapes[0] == 2) + { + targetShape = { inputShape[0] , numInputElements / inputShape[0] }; + } } if(m_TensorsInfo[node.input(0)].isConstant()) @@ -2116,20 +2146,11 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) } else { - TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); - if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr) { - uint64_t dims = static_cast(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size()); - TensorShape targetShape{static_cast(dims), 1}; - - for(uint i = 0; i < dims; i++) - { - int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast(i))); - targetShape[i]= static_cast(val); - } - - auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0)); + auto outInfo = ComputeReshapeInfo( + TensorShape(static_cast(targetShape.size()), targetShape.data()), + inputShape, node.output(0)); m_TensorsInfo[node.output(0)].m_info = std::make_unique(outInfo); } diff --git a/src/armnnOnnxParser/test/Reshape.cpp b/src/armnnOnnxParser/test/Reshape.cpp index e9bcd278cf..97198761e5 100644 --- a/src/armnnOnnxParser/test/Reshape.cpp +++ b/src/armnnOnnxParser/test/Reshape.cpp @@ -5,6 +5,7 @@ #include "armnnOnnxParser/IOnnxParser.hpp" #include "ParserPrototxtFixture.hpp" +#include "OnnxParserTestUtils.hpp" TEST_SUITE("OnnxParser_Reshape") { @@ -211,4 +212,257 @@ TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape") CHECK_THROWS_AS(Setup(), armnn::ParseException); } +struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture +{ + ReshapeNegativeReshapeFixture(const std::vector& inputShape, + const std::vector& shapeInputShape, + const std::vector& outputShape, + const std::string& shape) + { + m_Prototext = R"( + ir_version: 3 + producer_name: "onnx-example" + graph { + name: "ReshapeGrapn" + input { + name: "Input" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( + } + } + } + } + input { + name: "Shape" + type { + tensor_type { + elem_type: 7 + shape { + )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( + } + } + } + } + node { + input: "Input" + input: "Shape" + output: "Output" + name: "reshape" + op_type: "Reshape" + } + initializer { + dims: 2 + data_type: 7 + )" + shape + R"( + name: "Shape" + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + } + opset_import { + version: 7 + })"; + } +}; + +struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1") + { + Setup(); + } +}; + +struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, + { 2 }, + { 2, 6 }, + "int64_data: -1 int64_data: 6") + { + Setup(); + } +}; + +struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, + { 3 }, + { 3, 1, 4 }, + "int64_data: 3 int64_data: -1 int64_data: 4") + { + Setup(); + } +}; + +struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture +{ + ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture( + { 2, 3, 1, 2 }, + { 4 }, + { 3, 1, 2, 2 }, + "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1") + { + Setup(); + } +}; + +TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest") +{ + RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest") +{ + RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest") +{ + RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest") +{ + RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); +} + +struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture +{ + ReshapeNonConstShapeFixture(const std::vector& inputShape, + const std::vector& shapeInputShape, + const std::vector& outputShape) + { + m_Prototext = R"( + ir_version: 3 + producer_name: "onnx-example" + graph { + name: "ReshapeGrapn" + input { + name: "Input" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( + } + } + } + } + input { + name: "Shape" + type { + tensor_type { + elem_type: 7 + shape { + )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( + } + } + } + } + node { + input: "Input" + input: "Shape" + output: "Output" + name: "reshape" + op_type: "Reshape" + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + } + opset_import { + version: 7 + })"; + } +}; + +struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }) + { + Setup(); + } +}; + +struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 }) + { + Setup(); + } +}; + +struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 }) + { + } +}; + +struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture +{ + ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 }) + { + } +}; + +TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest") +{ + RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest") +{ + RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}); +} + +TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest") +{ + CHECK_THROWS_AS(Setup(), armnn::ParseException); +} + +TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest") +{ + CHECK_THROWS_AS(Setup(), armnn::ParseException); +} + } -- cgit v1.2.1