aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTianle Cheng <tianle.cheng@arm.com>2023-10-03 12:01:11 +0100
committerTianle Cheng <tianle.cheng@arm.com>2023-10-03 14:31:21 +0100
commit2077348da176dd5f264bb04a0b500c5901969d09 (patch)
tree6a2e944d16b2693bbc6a666cd3c5432e59265498
parentad323af0e9b47e53d366b85cdf74927f88748d40 (diff)
downloadarmnn-2077348da176dd5f264bb04a0b500c5901969d09.tar.gz
IVGCVSW-7749 DTS: Fix reshape floating point exception
* Updated Opaque Delegate, TfliteParser, OnnxParser, and Deserializer to handle the Zero In Shape edge case Signed-off-by: Tianle Cheng <tianle.cheng@arm.com> Change-Id: I4a0d1e72a66de1fa56de99af9b6730a84e0ff596
-rw-r--r--delegate/classic/src/Redefine.hpp12
-rw-r--r--delegate/common/src/DelegateUtils.hpp11
-rw-r--r--delegate/opaque/src/Redefine.hpp13
-rw-r--r--src/armnnDeserializer/Deserializer.cpp18
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp22
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp20
6 files changed, 90 insertions, 6 deletions
diff --git a/delegate/classic/src/Redefine.hpp b/delegate/classic/src/Redefine.hpp
index 6b10e448e7..c3422a2fb5 100644
--- a/delegate/classic/src/Redefine.hpp
+++ b/delegate/classic/src/Redefine.hpp
@@ -166,6 +166,18 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
return kTfLiteError;
}
+ // Check the target shape to check if there is zero in the shape.
+ if (std::find(targetShape.begin(), targetShape.end(), 0) != targetShape.end() &&
+ inputTensorInfo0.GetNumElements() != 0)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnDelegate: Input to reshape is a tensor with elements, "
+ "but the requested shape has 0. "
+ "operator #%d node #%d: ",
+ operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
// Use the data to create the required tensor shape.
if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
{
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp
index a74ed8b549..a2cdc83a64 100644
--- a/delegate/common/src/DelegateUtils.hpp
+++ b/delegate/common/src/DelegateUtils.hpp
@@ -186,7 +186,16 @@ TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));
auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
- outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+
+ if (targetNumElements == 0)
+ {
+ // To handle the edge case that input and output both have zero elements
+ outputDims[stretchIndex] = 0;
+ }
+ else
+ {
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
}
armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
diff --git a/delegate/opaque/src/Redefine.hpp b/delegate/opaque/src/Redefine.hpp
index 5ce7a3dcc1..6319ca7841 100644
--- a/delegate/opaque/src/Redefine.hpp
+++ b/delegate/opaque/src/Redefine.hpp
@@ -201,6 +201,19 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
return kTfLiteError;
}
+ // Check the target shape to check if there is zero in the shape.
+ if (std::find(targetShape.begin(), targetShape.end(), 0) != targetShape.end() &&
+ inputTensorInfo0.GetNumElements() != 0)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Input to reshape is a tensor with elements, "
+ "but the requested shape has 0. "
+ "operator #%d node #%d: ",
+ operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
// Use the data to create the required tensor shape.
if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
{
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 6e2c07bf37..f455b1af0a 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -2632,7 +2632,23 @@ armnn::TensorInfo IDeserializer::DeserializerImpl::OutputShapeOfReshape(const ar
std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
- outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ if (targetNumElements == 0)
+ {
+ if (inputTensorInfo.GetNumElements() == 0)
+ {
+ outputDims[stretchIndex] = 0;
+ }
+ else
+ {
+ throw ParseException(
+ fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
+ CHECK_LOCATION().AsString()));
+ }
+ }
+ else
+ {
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
}
TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index f165df9e14..5ec99ede74 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -441,10 +441,26 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
CHECK_LOCATION().AsString()));
}
- auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(),
- -1, std::multiplies<int32_t>()));
+ auto targetNumElements = armnn::numeric_cast<unsigned int>(
+ std::accumulate(targetDims.begin(), targetDims.end(), -1, std::multiplies<int32_t>()));
auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
- outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
+ if (targetNumElements == 0)
+ {
+ if (inShape.GetNumElements() == 0)
+ {
+ outDims[stretchIndex] = 0;
+ }
+ else
+ {
+ throw ParseException(
+ fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
+ CHECK_LOCATION().AsString()));
+ }
+ }
+ else
+ {
+ outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
+ }
}
TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
return TensorInfo(outShape, dataType);
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index c2be54f5f5..30875d5650 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -3252,6 +3252,7 @@ void TfLiteParserImpl::ParseActivation(size_t subgraphIndex, size_t operatorInde
auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
+
armnn::TensorInfo TfLiteParserImpl::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
const std::vector<int32_t>& targetDimsIn)
{
@@ -3271,7 +3272,24 @@ armnn::TensorInfo TfLiteParserImpl::OutputShapeOfReshape(const armnn::TensorInfo
std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
- outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+
+ if (targetNumElements == 0)
+ {
+ if (inputTensorInfo.GetNumElements() == 0)
+ {
+ outputDims[stretchIndex] = 0;
+ }
+ else
+ {
+ throw ParseException(
+ fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
+ CHECK_LOCATION().AsString()));
+ }
+ }
+ else
+ {
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
}
TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());