diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 6e2c07bf37..f455b1af0a 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -2632,7 +2632,23 @@ armnn::TensorInfo IDeserializer::DeserializerImpl::OutputShapeOfReshape(const ar std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>())); auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim)); - outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + if (targetNumElements == 0) + { + if (inputTensorInfo.GetNumElements() == 0) + { + outputDims[stretchIndex] = 0; + } + else + { + throw ParseException( + fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}", + CHECK_LOCATION().AsString())); + } + } + else + { + outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + } } TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data()); |