aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/Redefine.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/Redefine.hpp')
-rw-r--r--delegate/src/Redefine.hpp94
1 files changed, 43 insertions, 51 deletions
diff --git a/delegate/src/Redefine.hpp b/delegate/src/Redefine.hpp
index 91295768d6..e88038362f 100644
--- a/delegate/src/Redefine.hpp
+++ b/delegate/src/Redefine.hpp
@@ -90,62 +90,53 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
armnn::ReshapeDescriptor reshapeDesc;
-
- // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
- TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
std::vector<int32_t> targetShape;
- bool targetShapeFound = false;
- if (reshapeOptions != nullptr)
+ // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both.
+ if (numInputs == 2)
{
- // Options might be set without valid data. we need to check the dimensions are in a valid range.
- if (reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
+ // Get shape from the second input tensor
+ const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
+ if (IsDynamicTensor(tfLiteShapeInputTensor))
{
- uint64_t elementCounter = 1;
- for (int i=0; i < reshapeOptions->num_dimensions; ++i)
- {
- targetShape.push_back(reshapeOptions->shape[i]);
- if (reshapeOptions->shape[i] > 0)
- {
- elementCounter = elementCounter * reshapeOptions->shape[i];
- }
- }
- // Check the number of elements match, otherwise fall back to using the second input tensor.
- if (elementCounter <= inputTensorInfo0.GetNumElements())
- {
- targetShapeFound = true;
- }
+ TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnDelegate: Dynamic input tensors are not supported in "
+ "operator #%d node #%d: ",
+ operatorCode, nodeIndex);
+ return kTfLiteError;
}
- }
- if (!targetShapeFound)
- {
- if (numInputs == 2)
+
+ if (tfLiteShapeInputTensor.dims->size != 1)
{
- const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
- if (IsDynamicTensor(tfLiteShapeInputTensor))
- {
- TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
- "TfLiteArmnnDelegate: Dynamic input tensors are not supported in "
- "operator #%d node #%d: ",
- operatorCode, nodeIndex);
- return kTfLiteError;
- }
+ TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
+ "operator #%d node #%d: ",
+ operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
- if (tfLiteShapeInputTensor.dims->size != 1)
- {
- TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
- "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
- "operator #%d node #%d: ",
- operatorCode, nodeIndex);
- return kTfLiteError;
- }
+ // Get the shape data out of the input tensor
+ auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
+ auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
+ for (auto i=0; i < shapeTensorNumValues; ++i)
+ {
+ targetShape.push_back(*(shapeTensorDataPtr+i));
+ }
+ }
+ else
+ {
+ // Get shape from the builtin data
+ TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);
- // Get the shape data out of the input tensor
- auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
- auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
- for (auto i=0; i < shapeTensorNumValues; ++i)
+ if (reshapeOptions != nullptr)
+ {
+ // Options might be set without valid data. we need to check the dimensions are in a valid range.
+ if (reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
{
- targetShape.push_back(*(shapeTensorDataPtr+i));
+ for (int i=0; i < reshapeOptions->num_dimensions; ++i)
+ {
+ targetShape.push_back(reshapeOptions->shape[i]);
+ }
}
}
else
@@ -170,10 +161,11 @@ TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
{
- TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
- "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
- "operator #%d node #%d: ",
- operatorCode, nodeIndex);
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
+ "operator #%d node #%d: ",
+ operatorCode, nodeIndex);
return kTfLiteError;
}