Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--  src/armnnOnnxParser/OnnxParser.cpp  |  55
1 file changed, 38 insertions(+), 17 deletions(-)
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index eb24bb5425..d97fa1c4f1 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -2096,12 +2096,42 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
m_TensorsInfo[node.input(1)].m_dtype,
onnx::TensorProto::INT64); //shape
- if(!m_TensorsInfo[node.input(1)].isConstant())
+ TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+
+ std::vector<unsigned int> targetShape;
+ if(m_TensorsInfo[node.input(1)].isConstant())
{
- throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}",
- node.input(1),
- node.name(),
- CHECK_LOCATION().AsString()));
+ unsigned int dims = static_cast<unsigned int>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
+ targetShape.reserve(dims);
+
+ for(uint i = 0; i < dims; i++)
+ {
+ int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
+ targetShape.push_back(static_cast<unsigned int>(val));
+ }
+ }
+ else
+ {
+ // The parser only supports shape (batch, -1) or (-1) for non-constant shape input.
+ unsigned int dims = m_TensorsInfo[node.input(1)].m_info->GetNumDimensions();
+ TensorShape shapes = m_TensorsInfo[node.input(1)].m_info->GetShape();
+ if (dims != 1 || shapes[0] > 2)
+ {
+ throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}",
+ node.input(1),
+ node.name(),
+ CHECK_LOCATION().AsString()));
+ }
+
+ unsigned int numInputElements = m_TensorsInfo[node.input(0)].m_info->GetNumElements();
+ if (shapes[0] == 1)
+ {
+ targetShape = { numInputElements };
+ }
+ else if (shapes[0] == 2)
+ {
+ targetShape = { inputShape[0] , numInputElements / inputShape[0] };
+ }
}
if(m_TensorsInfo[node.input(0)].isConstant())
@@ -2116,20 +2146,11 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
}
else
{
- TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
-
if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
{
- uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
- TensorShape targetShape{static_cast<unsigned int>(dims), 1};
-
- for(uint i = 0; i < dims; i++)
- {
- int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
- targetShape[i]= static_cast<unsigned int>(val);
- }
-
- auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0));
+ auto outInfo = ComputeReshapeInfo(
+ TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
+ inputShape, node.output(0));
m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
}