diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-10-23 17:14:43 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-10-27 13:51:58 +0000 |
commit | 62483bee640e7d8accf6ac77b24c6e9828841851 (patch) | |
tree | ba7025bc86819c3d787428dd16b5be73b90a4353 /delegate/include | |
parent | 3d1323ff87fa92ff9cfc74097148b97fa1784416 (diff) | |
download | armnn-62483bee640e7d8accf6ac77b24c6e9828841851.tar.gz |
IVGCVSW-5366 'Add a do nothing SubGraph class'
IVGCVSW-5373 'Implement the ABS operator in the Delegate'
* Added a Switch statement into the VisitNode() function
* Separated the Visit functions into the categorized source files
* Implemented VisitElementwiseUnary() function
* Added tests for ABS and SQRT
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: If9654d0a8d8ff7dcd6fb5cbe0dc312941772affb
Diffstat (limited to 'delegate/include')
-rw-r--r-- | delegate/include/armnn_delegate.hpp | 135 |
1 files changed, 49 insertions, 86 deletions
diff --git a/delegate/include/armnn_delegate.hpp b/delegate/include/armnn_delegate.hpp index 6136f2bebe..6f18185d7b 100644 --- a/delegate/include/armnn_delegate.hpp +++ b/delegate/include/armnn_delegate.hpp @@ -3,7 +3,8 @@ // SPDX-License-Identifier: MIT // -#pragma once +#ifndef ARMNN_TFLITE_DELEGATE +#define ARMNN_TFLITE_DELEGATE #include "DelegateOptions.hpp" @@ -15,32 +16,51 @@ namespace armnnDelegate { -TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate); +struct DelegateData +{ + DelegateData(const std::vector<armnn::BackendId>& backends) + : m_Backends(backends) + , m_Network(nullptr, nullptr) + {} + + const std::vector<armnn::BackendId> m_Backends; + armnn::INetworkPtr m_Network; + std::vector<armnn::IOutputSlot*> m_OutputSlotForNode; +}; + +// Forward decleration for functions initializing the ArmNN Delegate +DelegateOptions TfLiteArmnnDelegateOptionsDefault(); + +TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options); + +void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate); -/// Delegate class +TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate); + +/// ArmNN Delegate class Delegate { friend class ArmnnSubgraph; public: explicit Delegate(armnnDelegate::DelegateOptions options); - TfLiteIntArray* CollectOperatorsToDelegate(TfLiteContext* context); + TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context); TfLiteDelegate* GetDelegate(); private: TfLiteDelegate m_Delegate = { reinterpret_cast<void*>(this), // .data_ - DelegatePrepare, // .Prepare + DoPrepare, // .Prepare nullptr, // .CopyFromBufferHandle nullptr, // .CopyToBufferHandle nullptr, // .FreeBufferHandle kTfLiteDelegateFlagsNone, // .flags }; - /// Arm NN Runtime pointer + /// ArmNN Runtime pointer armnn::IRuntimePtr m_Runtime; - /// Arm NN Delegate Options + /// ArmNN Delegate Options armnnDelegate::DelegateOptions m_Options; }; @@ -54,102 +74,45 @@ public: TfLiteStatus Prepare(TfLiteContext* tfLiteContext); - TfLiteStatus Invoke(TfLiteContext* tfLiteContext); + TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode); - static TfLiteStatus VisitNode(armnn::INetworkPtr& network, + static TfLiteStatus VisitNode(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteRegistration* tfLiteRegistration, TfLiteNode* tfLiteNode, int nodeIndex); private: - ArmnnSubgraph(armnn::NetworkId networkId, armnn::IRuntime* runtime) - : m_NetworkId(networkId), m_Runtime(runtime) + ArmnnSubgraph(armnn::NetworkId networkId, + armnn::IRuntime* runtime, + std::vector<armnn::BindingPointInfo>& inputBindings, + std::vector<armnn::BindingPointInfo>& outputBindings) + : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings) {} + static TfLiteStatus AddInputLayer(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const TfLiteIntArray* inputs, + std::vector<armnn::BindingPointInfo>& inputBindings); + + static TfLiteStatus AddOutputLayer(DelegateData& delegateData, + TfLiteContext* tfLiteContext, + const TfLiteIntArray* outputs, + std::vector<armnn::BindingPointInfo>& outputBindings); + + /// The Network Id armnn::NetworkId m_NetworkId; /// ArmNN Rumtime armnn::IRuntime* m_Runtime; -}; - -void* ArmnnSubgraphInit(TfLiteContext* tfLiteContext, const char* buffer, size_t length) -{ - const TfLiteDelegateParams* parameters = reinterpret_cast<const TfLiteDelegateParams*>(buffer); - return static_cast<void*>(ArmnnSubgraph::Create( - tfLiteContext, parameters, static_cast<::armnnDelegate::Delegate*>(parameters->delegate->data_))); -} + // Binding information for inputs and outputs + std::vector<armnn::BindingPointInfo> m_InputBindings; + std::vector<armnn::BindingPointInfo> m_OutputBindings; -TfLiteStatus ArmnnSubgraphPrepare(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode) -{ - if (tfLiteNode->user_data == nullptr) - { - return kTfLiteError; - } - - return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Prepare(tfLiteContext); -} - -TfLiteStatus ArmnnSubgraphInvoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode) -{ - if (tfLiteNode->user_data == nullptr) - { - return kTfLiteError; - } - - return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Invoke(tfLiteContext); -} - -void ArmnnSubgraphFree(TfLiteContext* tfLiteContext, void* buffer) -{ - if (buffer != nullptr) - { - delete static_cast<ArmnnSubgraph*>(buffer); - } -} - -const TfLiteRegistration armnnSubgraphRegistration = { - ArmnnSubgraphInit, // .init - ArmnnSubgraphFree, // .free - ArmnnSubgraphPrepare, // .prepare - ArmnnSubgraphInvoke, // .invoke - nullptr, // .profiling_string - 0, // .builtin_code - "TfLiteArmnnDelegate", // .custom_name - 1, // .version }; -TfLiteStatus DelegatePrepare(TfLiteContext* tfLiteContext, TfLiteDelegate* tfLiteDelegate) -{ - TfLiteIntArray* supportedOperators = - static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_)->CollectOperatorsToDelegate(tfLiteContext); - - const TfLiteStatus status = - tfLiteContext->ReplaceNodeSubsetsWithDelegateKernels( - tfLiteContext, armnnSubgraphRegistration, supportedOperators, tfLiteDelegate); - TfLiteIntArrayFree(supportedOperators); - - return status; -} - } // armnnDelegate namespace -armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault() { - armnnDelegate::DelegateOptions options(armnn::Compute::CpuRef); - return options; -} +#endif // ARMNN_TFLITE_DELEGATE -TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options) -{ - auto* armnnDelegate = new ::armnnDelegate::Delegate(options); - return armnnDelegate->GetDelegate(); -} - -void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate) -{ - if (tfLiteDelegate != nullptr) - { - delete static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_); - } -}
\ No newline at end of file |