diff options
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index f165df9e14..5ec99ede74 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -441,10 +441,26 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor, CHECK_LOCATION().AsString())); } - auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(), - -1, std::multiplies<int32_t>())); + auto targetNumElements = armnn::numeric_cast<unsigned int>( + std::accumulate(targetDims.begin(), targetDims.end(), -1, std::multiplies<int32_t>())); auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim)); - outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements; + if (targetNumElements == 0) + { + if (inShape.GetNumElements() == 0) + { + outDims[stretchIndex] = 0; + } + else + { + throw ParseException( + fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}", + CHECK_LOCATION().AsString())); + } + } + else + { + outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements; + } } TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()}; return TensorInfo(outShape, dataType); |