aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2023-04-03 16:32:57 +0100
committerryan.oshea3 <ryan.oshea3@arm.com>2023-04-05 21:37:14 +0000
commit54cf011c89ed7512853ec4472a6fd52fb8f9495f (patch)
tree080b931632df843d450c0b6ad25c2c78ddaa86df
parentebe392df1635790bf21714549adb97f2f75559e1 (diff)
downloadarmnn-54cf011c89ed7512853ec4472a6fd52fb8f9495f.tar.gz
IVGCVSW-7559 Implement DoPrepare with registration
* Added ArmnnOpaqueDelegate::IdentifyOperatorsToDelegate implementation. Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: I0b65847358d339a15fc3f729f89deb9b86da0c66
-rw-r--r--delegate/opaque/src/armnn_delegate.cpp175
1 files changed, 175 insertions, 0 deletions
diff --git a/delegate/opaque/src/armnn_delegate.cpp b/delegate/opaque/src/armnn_delegate.cpp
index 6983873697..2fbfda3628 100644
--- a/delegate/opaque/src/armnn_delegate.cpp
+++ b/delegate/opaque/src/armnn_delegate.cpp
@@ -98,6 +98,93 @@ ArmnnOpaqueDelegate::ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options)
TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnOpaqueDelegate: Created TfLite ArmNN delegate.");
}
+TfLiteStatus DoPrepare(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueDelegate* tfLiteDelegate)
+{
+ TfLiteIntArray* supportedOperators =
+ static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>
+ (tfLiteDelegate->data_)->IdentifyOperatorsToDelegate(tfLiteContext);
+ if(supportedOperators == nullptr)
+ {
+ return kTfLiteError;
+ }
+
+ // ArmNN Opaque Delegate Registration
+ TfLiteRegistrationExternal* kernelRegistration =
+ TfLiteRegistrationExternalCreate(kTfLiteBuiltinDelegate, "TfLiteArmNNOpaqueDelegate", /*version=*/1);
+ if(kernelRegistration == nullptr)
+ {
+ return kTfLiteError;
+ }
+
+ TfLiteRegistrationExternalSetInit(
+ kernelRegistration,
+ [](TfLiteOpaqueContext* tfLiteContext, const char* buffer, size_t length) -> void*
+ {
+ armnn::IgnoreUnused(length);
+ const TfLiteOpaqueDelegateParams* parameters =
+ reinterpret_cast<const TfLiteOpaqueDelegateParams*>(buffer);
+ if(parameters == nullptr)
+ {
+ TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Unable to get parameters.");
+ return nullptr;
+ }
+
+ return static_cast<void*>(
+ ArmnnSubgraph::Create(tfLiteContext,
+ parameters,
+ static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>(
+ parameters->delegate->data_)));
+ }
+ );
+
+ TfLiteRegistrationExternalSetFree(
+ kernelRegistration,
+ [](TfLiteOpaqueContext* tfLiteContext, void* buffer) -> void
+ {
+ armnn::IgnoreUnused(tfLiteContext);
+ if (buffer != nullptr)
+ {
+ delete static_cast<ArmnnSubgraph*>(buffer);
+ }
+ }
+ );
+
+ TfLiteRegistrationExternalSetPrepare(
+ kernelRegistration,
+ [](TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode) -> TfLiteStatus
+ {
+ void* userData = TfLiteOpaqueNodeGetUserData(tfLiteNode);
+ if (userData == nullptr)
+ {
+ return kTfLiteError;
+ }
+ return static_cast<ArmnnSubgraph*>(userData)->Prepare(tfLiteContext);
+ }
+ );
+
+ TfLiteRegistrationExternalSetInvoke(
+ kernelRegistration,
+ [](TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode) -> TfLiteStatus
+ {
+ void* userData = TfLiteOpaqueNodeGetUserData(tfLiteNode);
+ if (userData == nullptr)
+ {
+ return kTfLiteError;
+ }
+
+ return static_cast<ArmnnSubgraph*>(userData)->Invoke(tfLiteContext, tfLiteNode);
+ }
+ );
+
+ const TfLiteStatus status =
+ TfLiteOpaqueContextReplaceNodeSubsetsWithDelegateKernels(
+ tfLiteContext, kernelRegistration, supportedOperators, tfLiteDelegate);
+
+ TfLiteIntArrayFree(supportedOperators);
+ return status;
+}
+
TfLiteOpaqueDelegate* TfLiteArmnnOpaqueDelegateCreate(const void* settings)
{
// This method will always create Opaque Delegate with default settings until
@@ -134,6 +221,94 @@ const std::string ArmnnOpaqueDelegate::GetVersion() {
return OPAQUE_DELEGATE_VERSION;
}
+TfLiteIntArray* ArmnnOpaqueDelegate::IdentifyOperatorsToDelegate(TfLiteOpaqueContext* tfLiteContext)
+{
+ TfLiteIntArray* executionPlan = nullptr;
+ if (TfLiteOpaqueContextGetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
+ {
+ TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnOpaqueDelegate: Unable to get graph execution plan.");
+ return nullptr;
+ }
+
+ // Delegate data with null network
+ DelegateData delegateData(m_Options.GetBackends());
+
+ TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size);
+ if (nodesToDelegate == nullptr)
+ {
+ TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Unable to create int array from execution plan.");
+ return nullptr;
+ }
+ nodesToDelegate->size = 0;
+
+ std::set<int32_t> unsupportedOperators;
+
+ for (int i = 0; i < executionPlan->size; ++i)
+ {
+ const int nodeIndex = executionPlan->data[i];
+
+ // If TfLiteOpaqueNodes can be delegated to ArmNN
+ TfLiteOpaqueNode* tfLiteNode = nullptr;
+ TfLiteRegistrationExternal* tfLiteRegistration = nullptr;
+
+ if (TfLiteOpaqueContextGetNodeAndRegistration(
+ tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
+ {
+ TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Unable to get node and registration for node %d.",
+ nodeIndex);
+ continue;
+ }
+
+ TfLiteStatus visitStatus;
+ try
+ {
+ visitStatus = ArmnnSubgraph::VisitNode(
+ delegateData, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex);
+ }
+ catch(std::exception& ex)
+ {
+ ARMNN_LOG(error) << "ArmNN Failed to visit node with error: " << ex.what();
+ visitStatus = kTfLiteError;
+ }
+
+ if (visitStatus != kTfLiteOk)
+ {
+ // node is not supported by ArmNN
+ unsupportedOperators.insert(TfLiteRegistrationExternalGetBuiltInCode(tfLiteRegistration));
+ continue;
+ }
+
+ nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex;
+ }
+
+ for (std::set<int32_t>::iterator it=unsupportedOperators.begin(); it!=unsupportedOperators.end(); ++it)
+ {
+ TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
+ "Operator %s [%d] is not supported by armnn_opaque_delegate.",
+ tflite::EnumNameBuiltinOperator(tflite::BuiltinOperator(*it)),
+ *it);
+ }
+
+ if (!unsupportedOperators.empty() && m_Options.TfLiteRuntimeFallbackDisabled())
+ {
+ std::stringstream exMessage;
+ exMessage << "TfLiteArmnnOpaqueDelegate: There are unsupported operators in the model. ";
+ exMessage << "Not falling back to TfLite Runtime as fallback is disabled. ";
+ exMessage << "This should only be disabled under test conditions.";
+ throw armnn::Exception(exMessage.str());
+ }
+ if (nodesToDelegate->size == 0)
+ {
+ ARMNN_LOG(info) << "No operators in this model are supported by the Arm NN TfLite delegate." <<
+ " The model will be executed entirely by TfLite runtime.";
+ }
+
+ std::sort(&nodesToDelegate->data[0], &nodesToDelegate->data[nodesToDelegate->size]);
+ return nodesToDelegate;
+}
+
TfLiteStatus ArmnnSubgraph::AddInputLayer(DelegateData& delegateData,
TfLiteOpaqueContext* tfLiteContext,
const TfLiteIntArray* inputs,