diff options
Diffstat (limited to 'delegate/include')
-rw-r--r-- | delegate/include/armnn_delegate.hpp | 135 |
1 files changed, 49 insertions, 86 deletions
diff --git a/delegate/include/armnn_delegate.hpp b/delegate/include/armnn_delegate.hpp index 6136f2bebe..6f18185d7b 100644 --- a/delegate/include/armnn_delegate.hpp +++ b/delegate/include/armnn_delegate.hpp @@ -3,7 +3,8 @@ // SPDX-License-Identifier: MIT // -#pragma once +#ifndef ARMNN_TFLITE_DELEGATE +#define ARMNN_TFLITE_DELEGATE #include "DelegateOptions.hpp" @@ -15,32 +16,51 @@ namespace armnnDelegate { -TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate); +struct DelegateData +{ + DelegateData(const std::vector<armnn::BackendId>& backends) + : m_Backends(backends) + , m_Network(nullptr, nullptr) + {} + + const std::vector<armnn::BackendId> m_Backends; + armnn::INetworkPtr m_Network; + std::vector<armnn::IOutputSlot*> m_OutputSlotForNode; +}; + +// Forward decleration for functions initializing the ArmNN Delegate +DelegateOptions TfLiteArmnnDelegateOptionsDefault(); + +TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options); + +void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate); -/// Delegate class +TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate); + +/// ArmNN Delegate class Delegate { friend class ArmnnSubgraph; public: explicit Delegate(armnnDelegate::DelegateOptions options); - TfLiteIntArray* CollectOperatorsToDelegate(TfLiteContext* context); + TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context); TfLiteDelegate* GetDelegate(); private: TfLiteDelegate m_Delegate = { reinterpret_cast<void*>(this), // .data_ - DelegatePrepare, // .Prepare + DoPrepare, // .Prepare nullptr, // .CopyFromBufferHandle nullptr, // .CopyToBufferHandle nullptr, // .FreeBufferHandle kTfLiteDelegateFlagsNone, // .flags }; - /// Arm NN Runtime pointer + /// ArmNN Runtime pointer armnn::IRuntimePtr m_Runtime; - /// Arm NN Delegate Options + /// ArmNN Delegate Options armnnDelegate::DelegateOptions m_Options; }; @@ -54,102 +74,45 @@ public: TfLiteStatus Prepare(TfLiteContext* tfLiteContext); - TfLiteStatus Invoke(TfLiteContext* tfLiteContext); + TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode); - static TfLiteStatus VisitNode(armnn::INetworkPtr& network, + static TfLiteStatus VisitNode(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteRegistration* tfLiteRegistration, TfLiteNode* tfLiteNode, int nodeIndex); private: - ArmnnSubgraph(armnn::NetworkId networkId, armnn::IRuntime* runtime) - : m_NetworkId(networkId), m_Runtime(runtime) + ArmnnSubgraph(armnn::NetworkId networkId, + armnn::IRuntime* runtime, + std::vector<armnn::BindingPointInfo>& inputBindings, + std::vector<armnn::BindingPointInfo>& outputBindings) + : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings) {} + static TfLiteStatus AddInputLayer(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const TfLiteIntArray* inputs, + std::vector<armnn::BindingPointInfo>& inputBindings); + + static TfLiteStatus AddOutputLayer(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const TfLiteIntArray* outputs, + std::vector<armnn::BindingPointInfo>& outputBindings); + + /// The Network Id armnn::NetworkId m_NetworkId; /// ArmNN Rumtime armnn::IRuntime* m_Runtime; -}; - -void* ArmnnSubgraphInit(TfLiteContext* tfLiteContext, const char* buffer, size_t length) -{ - const TfLiteDelegateParams* parameters = reinterpret_cast<const TfLiteDelegateParams*>(buffer); - return static_cast<void*>(ArmnnSubgraph::Create( - tfLiteContext, parameters, static_cast<::armnnDelegate::Delegate*>(parameters->delegate->data_))); -} + // Binding information for inputs and outputs + std::vector<armnn::BindingPointInfo> m_InputBindings; + std::vector<armnn::BindingPointInfo> m_OutputBindings; -TfLiteStatus ArmnnSubgraphPrepare(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode) -{ - if (tfLiteNode->user_data == nullptr) - { - return kTfLiteError; - } - - return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Prepare(tfLiteContext); -} - -TfLiteStatus ArmnnSubgraphInvoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode) -{ - if (tfLiteNode->user_data == nullptr) - { - return kTfLiteError; - } - - return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Invoke(tfLiteContext); -} - -void ArmnnSubgraphFree(TfLiteContext* tfLiteContext, void* buffer) -{ - if (buffer != nullptr) - { - delete static_cast<ArmnnSubgraph*>(buffer); - } -} - -const TfLiteRegistration armnnSubgraphRegistration = { - ArmnnSubgraphInit, // .init - ArmnnSubgraphFree, // .free - ArmnnSubgraphPrepare, // .prepare - ArmnnSubgraphInvoke, // .invoke - nullptr, // .profiling_string - 0, // .builtin_code - "TfLiteArmnnDelegate", // .custom_name - 1, // .version }; -TfLiteStatus DelegatePrepare(TfLiteContext* tfLiteContext, TfLiteDelegate* tfLiteDelegate) -{ - TfLiteIntArray* supportedOperators = - static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_)->CollectOperatorsToDelegate(tfLiteContext); - - const TfLiteStatus status = - tfLiteContext->ReplaceNodeSubsetsWithDelegateKernels( - tfLiteContext, armnnSubgraphRegistration, supportedOperators, tfLiteDelegate); - TfLiteIntArrayFree(supportedOperators); - - return status; -} - } // armnnDelegate namespace -armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault() { - armnnDelegate::DelegateOptions options(armnn::Compute::CpuRef); - return options; -} +#endif // ARMNN_TFLITE_DELEGATE -TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options) -{ - auto* armnnDelegate = new ::armnnDelegate::Delegate(options); - return armnnDelegate->GetDelegate(); -} - -void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate) -{ - if (tfLiteDelegate != nullptr) - { - delete static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_); - } -}
\ No newline at end of file |