From 8b9858d891439fd1b0710e5d245e2116a3b88d30 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Mon, 9 Nov 2020 08:26:22 +0000 Subject: IVGCVSW-5380 'TfLiteDelegate: Implement the Comparison operators' * Implemented Comparison Operators * Added unit tests Signed-off-by: Sadik Armagan Change-Id: Icdc0f7c6a286a8364a2770b26d15e8958291dc2b --- delegate/src/Comparison.hpp | 107 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 2 deletions(-) (limited to 'delegate/src/Comparison.hpp') diff --git a/delegate/src/Comparison.hpp b/delegate/src/Comparison.hpp index 19d8de10e1..f787a22090 100644 --- a/delegate/src/Comparison.hpp +++ b/delegate/src/Comparison.hpp @@ -5,6 +5,8 @@ #pragma once +#include "DelegateUtils.hpp" + #include #include #include @@ -17,9 +19,110 @@ TfLiteStatus VisitComparisonOperator(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode, int nodeIndex, - int32_t comparisonOperatorCode) + int32_t tfLiteComparisonOperatorCode) { - return kTfLiteError; + TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex)); + TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex)); + + const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors; + const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]]; + if (IsDynamicTensor(tfLiteInputTensor0)) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ", + tfLiteComparisonOperatorCode, nodeIndex); + return kTfLiteError; + } + + const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]]; + if (IsDynamicTensor(tfLiteInputTensor1)) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ", + tfLiteComparisonOperatorCode, nodeIndex); + return kTfLiteError; + } + + const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]]; + if (IsDynamicTensor(tfLiteOutputTensor)) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ", + tfLiteComparisonOperatorCode, nodeIndex); + return kTfLiteError; + } + + const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0); + const armnn::TensorInfo& inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1); + const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + + armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal; + switch(tfLiteComparisonOperatorCode) + { + case kTfLiteBuiltinEqual: + comparisonOperation = armnn::ComparisonOperation::Equal; + break; + case kTfLiteBuiltinGreater: + comparisonOperation = armnn::ComparisonOperation::Greater; + break; + case kTfLiteBuiltinGreaterEqual: + comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual; + break; + case kTfLiteBuiltinLess: + comparisonOperation = armnn::ComparisonOperation::Less; + break; + case kTfLiteBuiltinLessEqual: + comparisonOperation = armnn::ComparisonOperation::LessOrEqual; + break; + case kTfLiteBuiltinNotEqual: + comparisonOperation = armnn::ComparisonOperation::NotEqual; + break; + default: + return kTfLiteError; + } + + armnn::ComparisonDescriptor descriptor(comparisonOperation); + bool isSupported = false; + + auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + tfLiteContext, + IsComparisonSupported, + delegateData.m_Backends, + isSupported, + inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo, + descriptor); + }; + + if (!delegateData.m_Network) + { + validateFunc(outputTensorInfo, isSupported); + return isSupported ? kTfLiteOk : kTfLiteError; + } + + armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor); + ARMNN_ASSERT(comparisonLayer != nullptr); + + armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0); + outputSlot.SetTensorInfo(outputTensorInfo); + + auto reshapeLayer = BroadcastTensor(inputTensorInfo0, + inputTensorInfo1, + comparisonLayer, + tfLiteContext, + tfLiteNode, + delegateData); + if (!reshapeLayer) + { + return kTfLiteError; + } + return kTfLiteOk; } } // namespace armnnDelegate -- cgit v1.2.1