aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-10-18 12:35:19 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2021-10-21 13:14:51 +0000
commit4b536e323abd09da9630502a8fb7d0be50e1ad45 (patch)
tree2b819a7673476fbbceac43ca0ed3ff6db492253d /src/armnnOnnxParser/OnnxParser.cpp
parentf437213e4b54f0179129395828e549c02973e02f (diff)
downloadarmnn-4b536e323abd09da9630502a8fb7d0be50e1ad45.tar.gz
IVGCVSW-6451 Add support for Reshape when the target shape is dynamic
and batch size is unknown to ONNX parser Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I46b2daccce9e1a21d9d0550ac4126d2c79dbd37b
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp55
1 files changed, 38 insertions, 17 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index eb24bb5425..d97fa1c4f1 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -2096,12 +2096,42 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
m_TensorsInfo[node.input(1)].m_dtype,
onnx::TensorProto::INT64); //shape
- if(!m_TensorsInfo[node.input(1)].isConstant())
+ TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+
+ std::vector<unsigned int> targetShape;
+ if(m_TensorsInfo[node.input(1)].isConstant())
{
- throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}",
- node.input(1),
- node.name(),
- CHECK_LOCATION().AsString()));
+ unsigned int dims = static_cast<unsigned int>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
+ targetShape.reserve(dims);
+
+ for(uint i = 0; i < dims; i++)
+ {
+ int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
+ targetShape[i]= static_cast<unsigned int>(val);
+ }
+ }
+ else
+ {
+ // The parser only supports shape (batch, -1) or (-1) for non-constant shape input.
+ unsigned int dims = m_TensorsInfo[node.input(1)].m_info->GetNumDimensions();
+ TensorShape shapes = m_TensorsInfo[node.input(1)].m_info->GetShape();
+ if (dims != 1 || shapes[0] > 2)
+ {
+ throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}",
+ node.input(1),
+ node.name(),
+ CHECK_LOCATION().AsString()));
+ }
+
+ unsigned int numInputElements = m_TensorsInfo[node.input(0)].m_info->GetNumElements();
+ if (shapes[0] == 1)
+ {
+ targetShape = { numInputElements };
+ }
+ else if (shapes[0] == 2)
+ {
+ targetShape = { inputShape[0] , numInputElements / inputShape[0] };
+ }
}
if(m_TensorsInfo[node.input(0)].isConstant())
@@ -2116,20 +2146,11 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
}
else
{
- TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
-
if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
{
- uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
- TensorShape targetShape{static_cast<unsigned int>(dims), 1};
-
- for(uint i = 0; i < dims; i++)
- {
- int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
- targetShape[i]= static_cast<unsigned int>(val);
- }
-
- auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0));
+ auto outInfo = ComputeReshapeInfo(
+ TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
+ inputShape, node.output(0));
m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
}