diff options
Diffstat (limited to 'delegate')
-rw-r--r-- | delegate/classic/src/Redefine.hpp | 12 | ||||
-rw-r--r-- | delegate/common/src/DelegateUtils.hpp | 11 | ||||
-rw-r--r-- | delegate/opaque/src/Redefine.hpp | 13 |
3 files changed, 35 insertions, 1 deletions
diff --git a/delegate/classic/src/Redefine.hpp b/delegate/classic/src/Redefine.hpp index 6b10e448e7..c3422a2fb5 100644 --- a/delegate/classic/src/Redefine.hpp +++ b/delegate/classic/src/Redefine.hpp @@ -166,6 +166,18 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData, return kTfLiteError; } + // Check the target shape to check if there is zero in the shape. + if (std::find(targetShape.begin(), targetShape.end(), 0) != targetShape.end() && + inputTensorInfo0.GetNumElements() != 0) + { + TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext, + "TfLiteArmnnDelegate: Input to reshape is a tensor with elements, " + "but the requested shape has 0. " + "operator #%d node #%d: ", + operatorCode, nodeIndex); + return kTfLiteError; + } + // Use the data to create the required tensor shape. if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk) { diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp index a74ed8b549..a2cdc83a64 100644 --- a/delegate/common/src/DelegateUtils.hpp +++ b/delegate/common/src/DelegateUtils.hpp @@ -186,7 +186,16 @@ TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo, std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>())); auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim)); - outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + + if (targetNumElements == 0) + { + // To handle the edge case that input and output both have zero elements + outputDims[stretchIndex] = 0; + } + else + { + outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + } } armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()), diff --git a/delegate/opaque/src/Redefine.hpp b/delegate/opaque/src/Redefine.hpp index 5ce7a3dcc1..6319ca7841 100644 --- a/delegate/opaque/src/Redefine.hpp +++ b/delegate/opaque/src/Redefine.hpp @@ -201,6 +201,19 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData, return kTfLiteError; } + // Check the target shape to check if there is zero in the shape. + if (std::find(targetShape.begin(), targetShape.end(), 0) != targetShape.end() && + inputTensorInfo0.GetNumElements() != 0) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Input to reshape is a tensor with elements, " + "but the requested shape has 0. " + "operator #%d node #%d: ", + operatorCode, nodeIndex); + return kTfLiteError; + } + // Use the data to create the required tensor shape. if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk) { |