From 98427a19b7e820283909d3e4ae00bc9447e461fc Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Wed, 25 Nov 2020 18:22:57 +0000 Subject: IVGCVSW-5384 TfLiteDelegate: Implement the Gather operator Signed-off-by: Teresa Charlin Change-Id: Iaf2112363d2b191327711d8e083fee2a751c35c5 --- delegate/src/Gather.hpp | 96 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 80 insertions(+), 16 deletions(-) (limited to 'delegate/src/Gather.hpp') diff --git a/delegate/src/Gather.hpp b/delegate/src/Gather.hpp index 98d8dc9656..9ed0fe15c1 100644 --- a/delegate/src/Gather.hpp +++ b/delegate/src/Gather.hpp @@ -5,29 +5,93 @@ #pragma once -#include - -#include -#include -#include -#include +#include "DelegateUtils.hpp" +#include +#include +#include +#include namespace armnnDelegate { - TfLiteStatus VisitGatherOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode, int nodeIndex, - int32_t gatherOperatorCode) + int32_t operatorCode) { - armnn::IgnoreUnused(delegateData, - tfLiteContext, - tfLiteNode, - nodeIndex, - gatherOperatorCode); + TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex)); + TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); - return kTfLiteError; -} + const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors; + + const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]]; + if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]]; + if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]]; + if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex)) + { + return kTfLiteError; + } + + auto* gatherParameters = reinterpret_cast(tfLiteNode->builtin_data); + auto axis = gatherParameters->axis; -} // namespace armnnDelegate + const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); + const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor); + const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + armnn::GatherDescriptor gatherDescriptor; + gatherDescriptor.m_Axis = axis; + + auto inputDimensions = static_cast(inputTensorInfo.GetNumDimensions()); + auto indicesDimensions = indicesTensorInfo.GetNumDimensions(); + auto outputDimensions = outputTensorInfo.GetNumDimensions(); + if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0))) + { + TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext, + "TfLiteArmnnDelegate: Operation has invalid axis: %d. It is out of bounds [-%d, %d))", + axis, inputDimensions, inputDimensions); + return kTfLiteError; + } + if (outputDimensions != static_cast(inputDimensions) + indicesDimensions - 1) + { + TF_LITE_MAYBE_KERNEL_LOG( tfLiteContext, + "Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor", + outputDimensions, inputDimensions, indicesDimensions); + return kTfLiteError; + } + + if (!delegateData.m_Network) + { + // Check if supported + bool isSupported = false; + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsGatherSupported, + delegateData.m_Backends, + isSupported, + inputTensorInfo, + indicesTensorInfo, + outputTensorInfo, + gatherDescriptor); + return isSupported ? kTfLiteOk : kTfLiteError; + } + + armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherLayer(gatherDescriptor); + ARMNN_ASSERT(layer != nullptr); + + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + Connect(layer, tfLiteNode, delegateData); + + return kTfLiteOk; +} +} // namespace armnnDelegate \ No newline at end of file -- cgit v1.2.1