aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp18
1 files changed, 17 insertions, 1 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 6e2c07bf37..f455b1af0a 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -2632,7 +2632,23 @@ armnn::TensorInfo IDeserializer::DeserializerImpl::OutputShapeOfReshape(const ar
std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
- outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ if (targetNumElements == 0)
+ {
+ if (inputTensorInfo.GetNumElements() == 0)
+ {
+ outputDims[stretchIndex] = 0;
+ }
+ else
+ {
+ throw ParseException(
+ fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
+ CHECK_LOCATION().AsString()));
+ }
+ }
+ else
+ {
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
}
TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());