aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Mcloughlin <john.mcloughlin@arm.com>2023-04-26 20:14:47 +0100
committerJohn Mcloughlin <john.mcloughlin@arm.com>2023-04-26 20:14:47 +0100
commit559d9097ac8e95d53d507074dbb1ba69d42664a6 (patch)
tree30de17c6ff453764758d05f7ff6e5b09817b89f6
parent81b66f3aeea1d0e788b0ce2894a58fedc763470b (diff)
downloadarmnn-559d9097ac8e95d53d507074dbb1ba69d42664a6.tar.gz
IVGCVSW-7575 Implement opaque delegate for ArgMinMax operator
* Added Opaque ArgMinMax and associated test cases Signed-off-by: John Mcloughlin <john.mcloughlin@arm.com> Change-Id: I098b94cc35707370a0bbc7456bfdd48bb47432f0
-rw-r--r--delegate/CMakeLists.txt2
-rw-r--r--delegate/opaque/CMakeLists.txt1
-rw-r--r--delegate/opaque/src/ArgMinMax.hpp160
-rw-r--r--delegate/opaque/src/armnn_delegate.cpp12
4 files changed, 175 insertions, 0 deletions
diff --git a/delegate/CMakeLists.txt b/delegate/CMakeLists.txt
index 7dc89d79cf..3c5033af5b 100644
--- a/delegate/CMakeLists.txt
+++ b/delegate/CMakeLists.txt
@@ -259,6 +259,8 @@ if(BUILD_UNIT_TESTS)
common/src/test/DelegateTestInterpreterUtils.hpp
opaque/src/test/ArmnnOpaqueDelegateTest.cpp
opaque/src/test/DelegateTestInterpreter.cpp
+ test/ArgMinMaxTest.cpp
+ test/ArgMinMaxTestHelper.hpp
test/BatchSpaceTest.cpp
test/BatchSpaceTestHelper.hpp
test/CastTest.cpp
diff --git a/delegate/opaque/CMakeLists.txt b/delegate/opaque/CMakeLists.txt
index 958dcf6014..156f79b0a9 100644
--- a/delegate/opaque/CMakeLists.txt
+++ b/delegate/opaque/CMakeLists.txt
@@ -7,6 +7,7 @@ set(armnnOpaqueDelegateObject_sources)
list(APPEND armnnOpaqueDelegateObject_sources
include/armnn_delegate.hpp
include/Version.hpp
+ src/ArgMinMax.hpp
src/armnn_delegate.cpp
src/BatchSpace.hpp
src/Convolution.hpp
diff --git a/delegate/opaque/src/ArgMinMax.hpp b/delegate/opaque/src/ArgMinMax.hpp
index e16969768e..7dfd89f57b 100644
--- a/delegate/opaque/src/ArgMinMax.hpp
+++ b/delegate/opaque/src/ArgMinMax.hpp
@@ -2,3 +2,163 @@
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
+#include <OpaqueDelegateUtils.hpp>
+
+#include <tensorflow/lite/builtin_ops.h>
+#include <tensorflow/lite/c/builtin_op_data.h>
+#include <tensorflow/lite/c/common.h>
+#include <tensorflow/lite/minimal_logging.h>
+
+namespace armnnOpaqueDelegate
+{
+TfLiteStatus VisitArgMinMaxOperator(DelegateData& delegateData,
+ TfLiteOpaqueContext* tfLiteContext,
+ TfLiteOpaqueNode* tfLiteNode,
+ int nodeIndex,
+ int32_t argMinMaxOperatorCode)
+{
+ TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
+ TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
+
+ // Gather input indices and use to get input tensor.
+ auto numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
+ const int* inputTensors;
+ if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Unable to gather input tensor indices from node #%d: ",
+ nodeIndex);
+ return kTfLiteError;
+ }
+
+ const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[0]);
+ if (!IsValid(tfLiteContext, tfLiteInputTensor, argMinMaxOperatorCode, nodeIndex))
+ {
+ return kTfLiteError;
+ }
+
+ // Use input indices to get filter tensor.
+ const TfLiteOpaqueTensor* tfLiteAxisTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[1]);
+ if(!IsValid(tfLiteAxisTensor))
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Invalid filter tensor in operator #%d node #%d: ",
+ argMinMaxOperatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
+ // Gather output indices and use to get output tensors.
+ int numOutputs = 0;
+ const int* outputTensors;
+ if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Unable to gather output tensor indices from node #%d: ",
+ nodeIndex);
+ return kTfLiteError;
+ }
+
+ const TfLiteOpaqueTensor* tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[0]);
+ if (!IsValid(tfLiteContext, tfLiteInputTensor, argMinMaxOperatorCode, nodeIndex))
+ {
+ return kTfLiteError;
+ }
+
+ const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
+ const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
+
+ // Get const axis value from model and set it to descriptor.
+ if (!IsValid(tfLiteContext, tfLiteAxisTensor, argMinMaxOperatorCode, nodeIndex))
+ {
+ return kTfLiteError;
+ }
+
+ armnn::ArgMinMaxDescriptor desc;
+ auto* axisData = static_cast<int*>(TfLiteOpaqueTensorData(tfLiteAxisTensor));
+ // Get the axis value from the input tensor
+ switch (TfLiteOpaqueTensorType(tfLiteAxisTensor))
+ {
+ case kTfLiteInt32:
+ case kTfLiteInt64:
+ desc.m_Axis = axisData[0];
+ break;
+ default:
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Axis value data type is not supported in operator #%d node #%d: ",
+ argMinMaxOperatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
+ // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64.
+ if (argMinMaxOperatorCode == kTfLiteBuiltinArgMax)
+ {
+ desc.m_Function = armnn::ArgMinMaxFunction::Max;
+ auto* argMaxParameters = reinterpret_cast<TfLiteArgMaxParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
+ if (argMaxParameters->output_type != kTfLiteInt32 && argMaxParameters->output_type != kTfLiteInt64)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: output_type data type is not supported in operator #%d node #%d: ",
+ argMinMaxOperatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+ }
+ else
+ {
+ desc.m_Function = armnn::ArgMinMaxFunction::Min;
+ auto* argMinParameters = reinterpret_cast<TfLiteArgMinParams*>(TfLiteOpaqueNodeGetBuiltinData(tfLiteNode));
+ if (argMinParameters->output_type != kTfLiteInt32 && argMinParameters->output_type != kTfLiteInt64)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: output_type data type is not supported in operator #%d node #%d: ",
+ argMinMaxOperatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+ }
+
+ bool isSupported = false;
+ armnn::BackendId setBackend;
+ auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
+ {
+ FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("ARGMINMAX",
+ tfLiteContext,
+ IsArgMinMaxSupported,
+ delegateData.m_Backends,
+ isSupported,
+ setBackend,
+ inputTensorInfo,
+ outInfo,
+ desc);
+ };
+
+ if (!delegateData.m_Network)
+ {
+ validateFunc(outputTensorInfo, isSupported);
+ return isSupported ? kTfLiteOk : kTfLiteError;
+ }
+
+ // Add an ArgMinMax layer
+ armnn::IConnectableLayer* layer = delegateData.m_Network->AddArgMinMaxLayer(desc);
+ layer->SetBackendId(setBackend);
+ ARMNN_ASSERT(layer != nullptr);
+
+ armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
+ outputSlot.SetTensorInfo(outputTensorInfo);
+
+ // try to connect the Constant Inputs if there are any
+ if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
+ {
+ return kTfLiteError;
+ }
+
+ // Connect
+ return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
+}
+
+} \ No newline at end of file
diff --git a/delegate/opaque/src/armnn_delegate.cpp b/delegate/opaque/src/armnn_delegate.cpp
index 7f3d8cf9e9..4d43b9271e 100644
--- a/delegate/opaque/src/armnn_delegate.cpp
+++ b/delegate/opaque/src/armnn_delegate.cpp
@@ -622,6 +622,18 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData,
{
switch (TfLiteRegistrationExternalGetBuiltInCode(tfLiteRegistration))
{
+ case kTfLiteBuiltinArgMax:
+ return VisitArgMinMaxOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ kTfLiteBuiltinArgMax);
+ case kTfLiteBuiltinArgMin:
+ return VisitArgMinMaxOperator(delegateData,
+ tfLiteContext,
+ tfLiteNode,
+ nodeIndex,
+ kTfLiteBuiltinArgMin);
case kTfLiteBuiltinBatchToSpaceNd:
return VisitBatchToSpaceNdOperator(delegateData,
tfLiteContext,