diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-10-18 12:35:19 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2021-10-21 13:14:51 +0000 |
commit | 4b536e323abd09da9630502a8fb7d0be50e1ad45 (patch) | |
tree | 2b819a7673476fbbceac43ca0ed3ff6db492253d /src/armnnOnnxParser/OnnxParser.cpp | |
parent | f437213e4b54f0179129395828e549c02973e02f (diff) | |
download | armnn-4b536e323abd09da9630502a8fb7d0be50e1ad45.tar.gz |
IVGCVSW-6451 Add support for Reshape when the target shape is dynamic
and batch size is unknown to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I46b2daccce9e1a21d9d0550ac4126d2c79dbd37b
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 55 |
1 files changed, 38 insertions, 17 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index eb24bb5425..d97fa1c4f1 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -2096,12 +2096,42 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) m_TensorsInfo[node.input(1)].m_dtype, onnx::TensorProto::INT64); //shape - if(!m_TensorsInfo[node.input(1)].isConstant()) + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + + std::vector<unsigned int> targetShape; + if(m_TensorsInfo[node.input(1)].isConstant()) { - throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}", - node.input(1), - node.name(), - CHECK_LOCATION().AsString())); + unsigned int dims = static_cast<unsigned int>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size()); + targetShape.reserve(dims); + + for(uint i = 0; i < dims; i++) + { + int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i))); + targetShape[i]= static_cast<unsigned int>(val); + } + } + else + { + // The parser only supports shape (batch, -1) or (-1) for non-constant shape input. + unsigned int dims = m_TensorsInfo[node.input(1)].m_info->GetNumDimensions(); + TensorShape shapes = m_TensorsInfo[node.input(1)].m_info->GetShape(); + if (dims != 1 || shapes[0] > 2) + { + throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}", + node.input(1), + node.name(), + CHECK_LOCATION().AsString())); + } + + unsigned int numInputElements = m_TensorsInfo[node.input(0)].m_info->GetNumElements(); + if (shapes[0] == 1) + { + targetShape = { numInputElements }; + } + else if (shapes[0] == 2) + { + targetShape = { inputShape[0] , numInputElements / inputShape[0] }; + } } if(m_TensorsInfo[node.input(0)].isConstant()) @@ -2116,20 +2146,11 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) } else { - TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); - if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr) { - uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size()); - TensorShape targetShape{static_cast<unsigned int>(dims), 1}; - - for(uint i = 0; i < dims; i++) - { - int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i))); - targetShape[i]= static_cast<unsigned int>(val); - } - - auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0)); + auto outInfo = ComputeReshapeInfo( + TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()), + inputShape, node.output(0)); m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo); } |