aboutsummaryrefslogtreecommitdiff
path: root/delegate/opaque/src/Comparison.hpp
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2023-04-26 11:42:46 +0100
committerMatthew Sloyan <matthew.sloyan@arm.com>2023-04-26 11:36:49 +0000
commit2b04ec3b94da152281fbbc69f8539378589b1f56 (patch)
tree13fd9f3a8ca44cf4f3a53ccf3f44960cfe627475 /delegate/opaque/src/Comparison.hpp
parentf2dffdb00bdf3108ebda6aaa142249d208f0c507 (diff)
downloadarmnn-2b04ec3b94da152281fbbc69f8539378589b1f56.tar.gz
IVGCVSW-7579 IVGCVSW-7581 IVGCVSW-7583 Implement Comparison, Concat and Mean in Opaque Delegate
* Removed input slot check from Connect function as number of TFLite and Arm NN inputs can differ. * Moved SetupConcatViewOrigin function to DelegateUtils.hpp * Simplified validation checks in VistConvolution functions as IsValid and IsDynamic were already being called. Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: I858dbe4b643f9d350d9c38ea255ce5effbda4612
Diffstat (limited to 'delegate/opaque/src/Comparison.hpp')
-rw-r--r--delegate/opaque/src/Comparison.hpp141
1 files changed, 141 insertions, 0 deletions
diff --git a/delegate/opaque/src/Comparison.hpp b/delegate/opaque/src/Comparison.hpp
index e16969768e..046be83094 100644
--- a/delegate/opaque/src/Comparison.hpp
+++ b/delegate/opaque/src/Comparison.hpp
@@ -2,3 +2,144 @@
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
+#pragma once
+
+#include <OpaqueDelegateUtils.hpp>
+
+#include <tensorflow/lite/builtin_ops.h>
+#include <tensorflow/lite/c/builtin_op_data.h>
+#include <tensorflow/lite/c/common.h>
+#include <tensorflow/lite/minimal_logging.h>
+
+namespace armnnOpaqueDelegate
+{
+
+TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
+ TfLiteOpaqueContext* tfLiteContext,
+ TfLiteOpaqueNode* tfLiteNode,
+ int nodeIndex,
+ int32_t tfLiteComparisonOperatorCode)
+{
+ TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
+ TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
+
+ // Gather input indices and use to get input tensor.
+ int numInputs = 0;
+ const int* inputTensors;
+ if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
+ nodeIndex);
+ return kTfLiteError;
+ }
+
+ // Use input indices to get input tensors.
+ const TfLiteOpaqueTensor* tfLiteInputTensor0 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
+ if (!IsValid(tfLiteContext, tfLiteInputTensor0, tfLiteComparisonOperatorCode, nodeIndex))
+ {
+ return kTfLiteError;
+ }
+
+ const TfLiteOpaqueTensor* tfLiteInputTensor1 = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
+ if (!IsValid(tfLiteContext, tfLiteInputTensor1, tfLiteComparisonOperatorCode, nodeIndex))
+ {
+ return kTfLiteError;
+ }
+
+ // Gather output indices and use to get output tensors.
+ int numOutputs = 0;
+ const int* outputTensors;
+ if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
+ nodeIndex);
+ return kTfLiteError;
+ }
+
+ const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
+ if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteComparisonOperatorCode, nodeIndex))
+ {
+ return kTfLiteError;
+ }
+
+ armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor0);
+ armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor1);
+ const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
+
+ // Check if we need to expand the dims of the input tensor infos.
+ // This is required for a few of the backends.
+ if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
+ {
+ ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
+ }
+
+ armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
+ switch(tfLiteComparisonOperatorCode)
+ {
+ case kTfLiteBuiltinEqual:
+ comparisonOperation = armnn::ComparisonOperation::Equal;
+ break;
+ case kTfLiteBuiltinGreater:
+ comparisonOperation = armnn::ComparisonOperation::Greater;
+ break;
+ case kTfLiteBuiltinGreaterEqual:
+ comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
+ break;
+ case kTfLiteBuiltinLess:
+ comparisonOperation = armnn::ComparisonOperation::Less;
+ break;
+ case kTfLiteBuiltinLessEqual:
+ comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
+ break;
+ case kTfLiteBuiltinNotEqual:
+ comparisonOperation = armnn::ComparisonOperation::NotEqual;
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ armnn::ComparisonDescriptor descriptor(comparisonOperation);
+ bool isSupported = false;
+ armnn::BackendId setBackend;
+ auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
+ {
+ FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("COMPARISON",
+ tfLiteContext,
+ IsComparisonSupported,
+ delegateData.m_Backends,
+ isSupported,
+ setBackend,
+ inputTensorInfo0,
+ inputTensorInfo1,
+ outputTensorInfo,
+ descriptor);
+ };
+
+ if (!delegateData.m_Network)
+ {
+ validateFunc(outputTensorInfo, isSupported);
+ return isSupported ? kTfLiteOk : kTfLiteError;
+ }
+
+ armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor);
+ comparisonLayer->SetBackendId(setBackend);
+ ARMNN_ASSERT(comparisonLayer != nullptr);
+
+ armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
+ outputSlot.SetTensorInfo(outputTensorInfo);
+
+ // try to connect the Constant Inputs if there are any
+ if(ProcessInputs(comparisonLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
+ {
+ return kTfLiteError;
+ }
+
+ return Connect(comparisonLayer, tfLiteContext, tfLiteNode, delegateData);
+}
+
+} // namespace armnnDelegate