diff options
Diffstat (limited to 'delegate/opaque/src')
-rw-r--r-- | delegate/opaque/src/Redefine.hpp | 194 | ||||
-rw-r--r-- | delegate/opaque/src/armnn_delegate.cpp | 6 |
2 files changed, 184 insertions, 16 deletions
diff --git a/delegate/opaque/src/Redefine.hpp b/delegate/opaque/src/Redefine.hpp index 7dd8561de4..dc424cff00 100644 --- a/delegate/opaque/src/Redefine.hpp +++ b/delegate/opaque/src/Redefine.hpp @@ -4,15 +4,7 @@ // #pragma once -#include <armnn/utility/IgnoreUnused.hpp> - -#include "OpaqueDelegateUtils.hpp" - -#include <tensorflow/lite/builtin_ops.h> -#include <tensorflow/lite/c/builtin_op_data.h> -#include <tensorflow/lite/c/common.h> -#include <tensorflow/lite/minimal_logging.h> -#include <numeric> +#include <OpaqueDelegateUtils.hpp> namespace armnnOpaqueDelegate { @@ -62,13 +54,13 @@ TfLiteStatus VisitCastOperator(DelegateData& delegateData, armnn::BackendId setBackend; auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported) { FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("CAST", - tfLiteContext, - IsCastSupported, - delegateData.m_Backends, - isSupported, - setBackend, - inputTensorInfo, - outInfo); + tfLiteContext, + IsCastSupported, + delegateData.m_Backends, + isSupported, + setBackend, + inputTensorInfo, + outInfo); }; // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the @@ -97,4 +89,174 @@ TfLiteStatus VisitCastOperator(DelegateData& delegateData, // Connect return Connect(layer, tfLiteContext, tfLiteNode, delegateData); } + +TfLiteStatus VisitReshapeOperator(DelegateData& delegateData, + TfLiteOpaqueContext* tfLiteContext, + TfLiteOpaqueNode* tfLiteNode, + int nodeIndex, + int32_t operatorCode) +{ + auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode); + + if (numInputs == 2) + { + TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex)); + } + else + { + TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); + } + TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); + + // Gather input indices and use to get input tensor. + const int* inputTensors; + if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ", + nodeIndex); + return kTfLiteError; + } + + const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]); + if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + // Gather output indices and use to get output tensors. + int numOutputs = 0; + const int* outputTensors; + if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ", + nodeIndex); + return kTfLiteError; + } + + const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]); + if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor); + const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true); + + armnn::ReshapeDescriptor reshapeDesc; + std::vector<int32_t> targetShape; + + auto* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode)); + + // The new shape can be defined by either a second input tensor or by a builtin option, we need to check for both. + // Options might be set without valid data. we need to check the dimensions are in a valid range. + if (reshapeOptions && reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8) + { + for (int i = 0; i < reshapeOptions->num_dimensions; ++i) + { + targetShape.push_back(reshapeOptions->shape[i]); + } + } + else if (numInputs == 2) + { + // Get shape from the second input tensor + const TfLiteOpaqueTensor* tfLiteShapeInputTensor = + TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]); + if (!IsValid(tfLiteContext, tfLiteShapeInputTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + int32_t numDims = TfLiteOpaqueTensorNumDims(tfLiteShapeInputTensor); + if (numDims != 1) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Target 'shape' input is not a 1D tensor in " + "operator #%d node #%d: Falling back to TfLiteOptions.", + operatorCode, nodeIndex); + } + else + { + // Get the shape data out of the input tensor + auto* shapeTensorDataPtr = static_cast<int32_t*>(TfLiteOpaqueTensorData(tfLiteShapeInputTensor)); + int32_t shapeTensorNumValues = TfLiteOpaqueTensorDim(tfLiteShapeInputTensor, 0); + for (int32_t i = 0; i < shapeTensorNumValues; ++i) + { + targetShape.push_back(shapeTensorDataPtr[i]); + } + } + } + else + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Target shape not defined in reshape parameters or input tensor. " + "At least one method required in operator #%d node #%d: ", + operatorCode, nodeIndex); + return kTfLiteError; + } + + // Use the data to create the required tensor shape. + if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: At most one component of shape can be -1 in: " + "operator #%d node #%d: ", + operatorCode, nodeIndex); + return kTfLiteError; + } + + if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements()) + { + TF_LITE_OPAQUE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnOpaqueDelegate: Reshape, number of elements in output shape does not match input " + "operator #%d node #%d: ", + operatorCode, nodeIndex); + return kTfLiteError; + } + + bool isSupported = false; + armnn::BackendId setBackend; + auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported) + { + FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("RESHAPE", + tfLiteContext, + IsReshapeSupported, + delegateData.m_Backends, + isSupported, + setBackend, + inputTensorInfo0, + outInfo, + reshapeDesc); + }; + + if (!delegateData.m_Network) + { + validateFunc(outputTensorInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; + } + + armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc); + layer->SetBackendId(setBackend); + ARMNN_ASSERT(layer != nullptr); + + armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); + outputSlot.SetTensorInfo(outputTensorInfo); + + // try to connect the Constant Inputs if there are any + if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk ) + { + return kTfLiteError; + } + + // Connect + return Connect(layer, tfLiteContext, tfLiteNode, delegateData); +} + } diff --git a/delegate/opaque/src/armnn_delegate.cpp b/delegate/opaque/src/armnn_delegate.cpp index c96f75dcb3..2fd8142169 100644 --- a/delegate/opaque/src/armnn_delegate.cpp +++ b/delegate/opaque/src/armnn_delegate.cpp @@ -1002,6 +1002,12 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData, tfLiteNode, nodeIndex, kTfLiteBuiltinRelu6); + case kTfLiteBuiltinReshape: + return VisitReshapeOperator(delegateData, + tfLiteContext, + tfLiteNode, + nodeIndex, + kTfLiteBuiltinReshape); case kTfLiteBuiltinResizeNearestNeighbor: return VisitResizeOperator(delegateData, tfLiteContext, |