diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index c2be54f5f5..30875d5650 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -3252,6 +3252,7 @@ void TfLiteParserImpl::ParseActivation(size_t subgraphIndex, size_t operatorInde auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } + armnn::TensorInfo TfLiteParserImpl::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo, const std::vector<int32_t>& targetDimsIn) { @@ -3271,7 +3272,24 @@ armnn::TensorInfo TfLiteParserImpl::OutputShapeOfReshape(const armnn::TensorInfo std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>())); auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim)); - outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + + if (targetNumElements == 0) + { + if (inputTensorInfo.GetNumElements() == 0) + { + outputDims[stretchIndex] = 0; + } + else + { + throw ParseException( + fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}", + CHECK_LOCATION().AsString())); + } + } + else + { + outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + } } TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data()); |