diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-11-25 18:22:57 +0000 |
---|---|---|
committer | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-11-26 17:32:32 +0000 |
commit | 98427a19b7e820283909d3e4ae00bc9447e461fc (patch) | |
tree | 96eb4b5fc077774213597f58d87d1df91345b59e /delegate/src/Gather.hpp | |
parent | 1c717648a51af9058db90301fba3451845674ee2 (diff) | |
download | armnn-98427a19b7e820283909d3e4ae00bc9447e461fc.tar.gz |
IVGCVSW-5384 TfLiteDelegate: Implement the Gather operator
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Iaf2112363d2b191327711d8e083fee2a751c35c5
Diffstat (limited to 'delegate/src/Gather.hpp')
-rw-r--r-- | delegate/src/Gather.hpp | 96 |
1 files changed, 80 insertions, 16 deletions
diff --git a/delegate/src/Gather.hpp b/delegate/src/Gather.hpp index 98d8dc9656..9ed0fe15c1 100644 --- a/delegate/src/Gather.hpp +++ b/delegate/src/Gather.hpp @@ -5,29 +5,93 @@ #pragma once -#include <armnn/utility/IgnoreUnused.hpp> - -#include <tensorflow/lite/builtin_ops.h> -#include <tensorflow/lite/c/builtin_op_data.h> -#include <tensorflow/lite/c/common.h> -#include <tensorflow/lite/minimal_logging.h> +#include "DelegateUtils.hpp" +#include <algorithm> +#include <iterator> +#include <string> +#include <vector> namespace armnnDelegate { - TfLiteStatus VisitGatherOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode, int nodeIndex, - int32_t gatherOperatorCode) + int32_t operatorCode) { - armnn::IgnoreUnused(delegateData, - tfLiteContext, - tfLiteNode, - nodeIndex, - gatherOperatorCode); + TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex)); + TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); - return kTfLiteError; -} + const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors; + + const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]]; + if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]]; + if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]]; + if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + auto* gatherParameters = reinterpret_cast<TfLiteGatherParams*>(tfLiteNode->builtin_data); + auto axis = gatherParameters->axis; -} // namespace armnnDelegate + const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); + const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor); + const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + armnn::GatherDescriptor gatherDescriptor; + gatherDescriptor.m_Axis = axis; + + auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions()); + auto indicesDimensions = indicesTensorInfo.GetNumDimensions(); + auto outputDimensions = outputTensorInfo.GetNumDimensions(); + if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0))) + { + TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext, + "TfLiteArmnnDelegate: Operation has invalid axis: %d. It is out of bounds [-%d, %d))", + axis, inputDimensions, inputDimensions); + return kTfLiteError; + } + if (outputDimensions != static_cast<unsigned int>(inputDimensions) + indicesDimensions - 1) + { + TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext, + "Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor", + outputDimensions, inputDimensions, indicesDimensions); + return kTfLiteError; + } + + if (!delegateData.m_Network) + { + // Check if supported + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsGatherSupported, + delegateData.m_Backends, + isSupported, + inputTensorInfo, + indicesTensorInfo, + outputTensorInfo, + gatherDescriptor); + return isSupported ? kTfLiteOk : kTfLiteError; + } + + armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherLayer(gatherDescriptor); + ARMNN_ASSERT(layer != nullptr); + + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + Connect(layer, tfLiteNode, delegateData); + + return kTfLiteOk; +} +} // namespace armnnDelegate
\ No newline at end of file |