aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp22
1 files changed, 19 insertions, 3 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index f165df9e14..5ec99ede74 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -441,10 +441,26 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
CHECK_LOCATION().AsString()));
}
- auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(),
- -1, std::multiplies<int32_t>()));
+ auto targetNumElements = armnn::numeric_cast<unsigned int>(
+ std::accumulate(targetDims.begin(), targetDims.end(), -1, std::multiplies<int32_t>()));
auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
- outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
+ if (targetNumElements == 0)
+ {
+ if (inShape.GetNumElements() == 0)
+ {
+ outDims[stretchIndex] = 0;
+ }
+ else
+ {
+ throw ParseException(
+ fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
+ CHECK_LOCATION().AsString()));
+ }
+ }
+ else
+ {
+ outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
+ }
}
TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
return TensorInfo(outShape, dataType);