diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2021-01-19 17:24:21 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2021-01-19 17:24:21 +0000 |
commit | dc032fca290deb39af65050c254a701596b53fa8 (patch) | |
tree | e3957a2651f0fbfe9a13f3ff1d2f092178578257 /delegate/src/ArgMinMax.hpp | |
parent | 97bf84f6e162307fc3e8c53045ef0bc60a3e3289 (diff) | |
download | armnn-dc032fca290deb39af65050c254a701596b53fa8.tar.gz |
IVGCVSW-5399 'TfLiteDelegate: Implement the ArgMinMax operators'
* Added ARG_MIN and ARG_MAX support to armnn_delegate
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Ia000c4b64378e28320164edd4df2902ca13dcda6
Diffstat (limited to 'delegate/src/ArgMinMax.hpp')
-rw-r--r-- | delegate/src/ArgMinMax.hpp | 120 |
1 files changed, 112 insertions, 8 deletions
diff --git a/delegate/src/ArgMinMax.hpp b/delegate/src/ArgMinMax.hpp index 367ef2ed14..090d18ef65 100644 --- a/delegate/src/ArgMinMax.hpp +++ b/delegate/src/ArgMinMax.hpp @@ -5,11 +5,10 @@ #pragma once -#include <armnn/utility/IgnoreUnused.hpp> - #include <tensorflow/lite/builtin_ops.h> #include <tensorflow/lite/c/builtin_op_data.h> #include <tensorflow/lite/c/common.h> +#include <tensorflow/lite/kernels/internal/tensor_ctypes.h> #include <tensorflow/lite/minimal_logging.h> namespace armnnDelegate @@ -21,13 +20,118 @@ TfLiteStatus VisitArgMinMaxOperator(DelegateData& delegateData, int nodeIndex, int32_t argMinMaxOperatorCode) { - armnn::IgnoreUnused(delegateData, - tfLiteContext, - tfLiteNode, - nodeIndex, - argMinMaxOperatorCode); + TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex)); + TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); + + const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors; + const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]]; + if (!IsValid(tfLiteContext, tfLiteInputTensor, argMinMaxOperatorCode, nodeIndex)) + { + return kTfLiteError; + } + + const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]]; + if (!IsValid(tfLiteContext, tfLiteOutputTensor, argMinMaxOperatorCode, nodeIndex)) + { + return kTfLiteError; + } + + const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); + const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + + // Get const axis value from model and set it to descriptor. + const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]]; + if (!IsValid(tfLiteContext, tfLiteAxisTensor, argMinMaxOperatorCode, nodeIndex)) + { + return kTfLiteError; + } + + armnn::ArgMinMaxDescriptor desc; + // Get the axis value from the input tensor + switch (tfLiteAxisTensor.type) + { + case kTfLiteInt32: + case kTfLiteInt64: + desc.m_Axis = tflite::GetTensorData<int>(&tfLiteAxisTensor)[0]; + break; + default: + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Axis value data type is not supported in operator #%d node #%d: ", + argMinMaxOperatorCode, nodeIndex); + return kTfLiteError; + } + + // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64. + if (argMinMaxOperatorCode == kTfLiteBuiltinArgMax) + { + desc.m_Function = armnn::ArgMinMaxFunction::Max; + auto* argMaxParameters = reinterpret_cast<TfLiteArgMaxParams*>(tfLiteNode->builtin_data); + switch (argMaxParameters->output_type) + { + case kTfLiteInt32: + desc.m_Output_Type = armnn::DataType::Signed32; + break; + case kTfLiteInt64: + desc.m_Output_Type = armnn::DataType::Signed64; + break; + default: + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ", + argMinMaxOperatorCode, nodeIndex); + return kTfLiteError; + } + } + else + { + desc.m_Function = armnn::ArgMinMaxFunction::Min; + auto* argMinParameters = reinterpret_cast<TfLiteArgMinParams*>(tfLiteNode->builtin_data); + switch (argMinParameters->output_type) + { + case kTfLiteInt32: + desc.m_Output_Type = armnn::DataType::Signed32; + break; + case kTfLiteInt64: + desc.m_Output_Type = armnn::DataType::Signed64; + break; + default: + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ", + argMinMaxOperatorCode, nodeIndex); + return kTfLiteError; + } + } + + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsArgMinMaxSupported, + delegateData.m_Backends, + isSupported, + inputTensorInfo, + outInfo, + desc); + }; + + if (!delegateData.m_Network) + { + validateFunc(outputTensorInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; + } + + // Add an ArgMinMax layer + armnn::IConnectableLayer* layer = delegateData.m_Network->AddArgMinMaxLayer(desc); + ARMNN_ASSERT(layer != nullptr); + + armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); + outputSlot.SetTensorInfo(outputTensorInfo); - return kTfLiteError; + // Connect + return Connect(layer, tfLiteNode, delegateData); } } // namespace armnnDelegate |