// // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "DelegateOptions.hpp" #include #include #include #include namespace armnnDelegate { struct DelegateData { DelegateData(const std::vector& backends) : m_Backends(backends) , m_Network(nullptr, nullptr) {} const std::vector m_Backends; armnn::INetworkPtr m_Network; std::vector m_OutputSlotForNode; }; // Forward decleration for functions initializing the ArmNN Delegate DelegateOptions TfLiteArmnnDelegateOptionsDefault(); TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options); void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate); TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate); /// ArmNN Delegate class Delegate { friend class ArmnnSubgraph; public: explicit Delegate(armnnDelegate::DelegateOptions options); TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context); TfLiteDelegate* GetDelegate(); private: TfLiteDelegate m_Delegate = { reinterpret_cast(this), // .data_ DoPrepare, // .Prepare nullptr, // .CopyFromBufferHandle nullptr, // .CopyToBufferHandle nullptr, // .FreeBufferHandle kTfLiteDelegateFlagsNone, // .flags }; /// ArmNN Runtime pointer armnn::IRuntimePtr m_Runtime; /// ArmNN Delegate Options armnnDelegate::DelegateOptions m_Options; }; /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph class ArmnnSubgraph { public: static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext, const TfLiteDelegateParams* parameters, const Delegate* delegate); TfLiteStatus Prepare(TfLiteContext* tfLiteContext); TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode); static TfLiteStatus VisitNode(DelegateData& delegateData, TfLiteContext* tfLiteContext, TfLiteRegistration* tfLiteRegistration, TfLiteNode* tfLiteNode, int nodeIndex); private: ArmnnSubgraph(armnn::NetworkId networkId, armnn::IRuntime* runtime, std::vector& inputBindings, std::vector& outputBindings) : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings) {} static TfLiteStatus AddInputLayer(DelegateData& delegateData, TfLiteContext* tfLiteContext, const TfLiteIntArray* inputs, std::vector& inputBindings); static TfLiteStatus AddOutputLayer(DelegateData& delegateData, TfLiteContext* tfLiteContext, const TfLiteIntArray* outputs, std::vector& outputBindings); /// The Network Id armnn::NetworkId m_NetworkId; /// ArmNN Rumtime armnn::IRuntime* m_Runtime; // Binding information for inputs and outputs std::vector m_InputBindings; std::vector m_OutputBindings; }; } // armnnDelegate namespace