aboutsummaryrefslogtreecommitdiff
path: root/delegate/opaque/include/armnn_delegate.hpp
blob: e08f466f0958d2fcdca97b7e1faa533401eeb1ea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <DelegateOptions.hpp>

#include <tensorflow/lite/c/c_api_opaque.h>
#include <tensorflow/lite/core/experimental/acceleration/configuration/c/stable_delegate.h>

namespace armnnOpaqueDelegate
{

/// Bundle of state handed (by reference) to the ArmnnSubgraph helpers VisitNode,
/// AddInputLayer and AddOutputLayer.
struct DelegateData
{
    /// Copies @p backends into m_Backends and starts with a null network handle.
    /// NOTE(review): the two nullptr arguments are presumably the managed pointer and its
    /// deleter - confirm against the definition of armnn::INetworkPtr.
    DelegateData(const std::vector<armnn::BackendId>& backends) :
        m_Backends(backends),
        m_Network(nullptr, nullptr)
    {
    }

    /// Backend ids captured at construction; const, so never modified afterwards.
    const std::vector<armnn::BackendId> m_Backends;

    /// Network handle; null until something assigns it.
    armnn::INetworkPtr m_Network;

    /// Non-owning output-slot pointers.
    /// NOTE(review): the indexing scheme is not visible in this header - confirm with the users.
    std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
};

// Forward declarations of the free functions used to initialise, create and destroy the
// ArmNN opaque delegate. All of them are only declared here; the definitions live elsewhere.

/// Returns an ::armnnDelegate::DelegateOptions value.
/// NOTE(review): judging by the name these are the default options - confirm in the definition.
::armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault();

/// Creates an opaque delegate and returns a pointer to it.
/// @param settings Untyped, read-only pointer to the delegate settings.
///                 NOTE(review): the concrete type behind the void* is not visible here - confirm.
TfLiteOpaqueDelegate* TfLiteArmnnOpaqueDelegateCreate(const void* settings);

/// Counterpart to TfLiteArmnnOpaqueDelegateCreate: takes the delegate pointer to be deleted.
void TfLiteArmnnOpaqueDelegateDelete(TfLiteOpaqueDelegate* tfLiteDelegate);

/// Prepare callback; ArmnnOpaqueDelegate installs it in the .Prepare slot of its
/// TfLiteOpaqueDelegateBuilder.
/// NOTE(review): @p data is presumably the builder's data pointer (the ArmnnOpaqueDelegate
/// instance) - confirm in the definition.
TfLiteStatus DoPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueDelegate* delegate, void* data);

/// ArmNN Opaque Delegate
///
/// Owns the TfLiteOpaqueDelegateBuilder that is handed out through GetDelegateBuilder(), together
/// with the runtime pointer and the delegate options.
///
/// NOTE(review): m_Builder stores the address of this object in its first field. An implicitly
/// generated copy would copy that address verbatim and therefore still refer to the source
/// object. Instances should not be copied; consider deleting the copy operations once all
/// callers have been checked.
class ArmnnOpaqueDelegate
{
    friend class ArmnnSubgraph;
public:
    /// Constructs the delegate from the given options (defined out of line).
    explicit ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options);

    /// Returns a TfLiteIntArray for the given context (defined out of line).
    /// NOTE(review): presumably the indices of the nodes ArmNN can take over, with ownership of
    /// the array passing to the caller - confirm in the definition.
    TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteOpaqueContext* context);

    /// Returns a pointer to the builder stored inside this object; it is only valid while this
    /// object is alive.
    TfLiteOpaqueDelegateBuilder* GetDelegateBuilder() { return &m_Builder; }

    /// Retrieve version in X.Y.Z form
    static const std::string GetVersion();

private:
    /**
     * Returns a pointer to the armnn::IRuntime that is shared by all callers of this function.
     *
     * The runtime is a function-local static, so it is created on the first call only.
     * @param options Creation options; they are used by the first call and ignored by every
     *                later call, which simply returns the already created runtime.
     * @return Non-owning pointer to the shared runtime.
     */
    armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
    {
        /// Instantiated on first use.
        static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
        return instance.get();
    }

    /// Builder handed to TfLite. An object pointer converts to void* with a plain static_cast;
    /// no reinterpret_cast is required.
    TfLiteOpaqueDelegateBuilder m_Builder =
    {
            static_cast<void*>(this),       // .data_
            DoPrepare,                      // .Prepare
            nullptr,                        // .CopyFromBufferHandle
            nullptr,                        // .CopyToBufferHandle
            nullptr,                        // .FreeBufferHandle
            kTfLiteDelegateFlagsNone,       // .flags
    };

    /// ArmNN Runtime pointer (non-owning); null until the constructor assigns it.
    armnn::IRuntime* m_Runtime = nullptr;
    /// ArmNN Delegate Options
    armnnDelegate::DelegateOptions m_Options;
};

/// Always returns 0, whatever delegate is passed in.
///
/// Declared inline rather than static: this is a header, and a static definition would give
/// every including translation unit its own private copy (and an unused-function warning where
/// it is not referenced).
/// @param delegate Ignored.
/// @return Always 0.
inline int TfLiteArmnnOpaqueDelegateErrno(TfLiteOpaqueDelegate* delegate)
{
    // The parameter is part of the signature but deliberately unused.
    static_cast<void>(delegate);
    return 0;
}

/// Returns a pointer to the plugin API structure that TfLite needs in order to load this
/// delegate. Only declared here; the definition lives elsewhere.
const TfLiteOpaqueDelegatePlugin* GetArmnnDelegatePluginApi();

/// Stable-delegate descriptor object. Declared extern here, so the definition lives in
/// another translation unit.
extern const TfLiteStableDelegate TFL_TheStableDelegate;

/// ArmnnSubgraph: parses the TfLite nodes into ArmNN format and creates the ArmNN graph.
///
/// Instances cannot be constructed directly (the constructor is private); use Create().
class ArmnnSubgraph
{
public:
    /// Factory function (defined out of line).
    /// NOTE(review): returns a raw pointer; ownership presumably passes to the caller - confirm.
    static ArmnnSubgraph* Create(TfLiteOpaqueContext* tfLiteContext,
                                 const TfLiteOpaqueDelegateParams* parameters,
                                 const ArmnnOpaqueDelegate* delegate);

    /// Prepare step for this subgraph (defined out of line).
    TfLiteStatus Prepare(TfLiteOpaqueContext* tfLiteContext);

    /// Invoke step for this subgraph on the given node (defined out of line).
    TfLiteStatus Invoke(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode);

    /// Visits a single TfLite node, identified by @p nodeIndex, using @p delegateData
    /// (defined out of line).
    static TfLiteStatus VisitNode(DelegateData& delegateData,
                                  TfLiteOpaqueContext* tfLiteContext,
                                  TfLiteRegistrationExternal* tfLiteRegistration,
                                  TfLiteOpaqueNode* tfLiteNode,
                                  int nodeIndex);
private:
    /// Stores the network id and runtime pointer and takes copies of both binding vectors.
    /// The vectors are only read here, so they are accepted by reference-to-const.
    ArmnnSubgraph(armnn::NetworkId networkId,
                  armnn::IRuntime* runtime,
                  const std::vector<armnn::BindingPointInfo>& inputBindings,
                  const std::vector<armnn::BindingPointInfo>& outputBindings)
    : m_NetworkId(networkId)
    , m_Runtime(runtime)
    , m_InputBindings(inputBindings)
    , m_OutputBindings(outputBindings)
    {}

    /// Helper for the given input tensor indices; @p inputBindings is an in/out parameter
    /// (defined out of line).
    static TfLiteStatus AddInputLayer(DelegateData& delegateData,
                                      TfLiteOpaqueContext* tfLiteContext,
                                      const TfLiteIntArray* inputs,
                                      std::vector<armnn::BindingPointInfo>& inputBindings);

    /// Helper for the given output tensor indices; @p outputBindings is an in/out parameter
    /// (defined out of line).
    static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
                                       TfLiteOpaqueContext* tfLiteContext,
                                       const TfLiteIntArray* outputs,
                                       std::vector<armnn::BindingPointInfo>& outputBindings);

    /// The Network Id
    armnn::NetworkId m_NetworkId;
    /// ArmNN Runtime (non-owning pointer supplied at construction)
    armnn::IRuntime* m_Runtime;
    /// Binding information for inputs and outputs
    std::vector<armnn::BindingPointInfo> m_InputBindings;
    std::vector<armnn::BindingPointInfo> m_OutputBindings;
};

} // armnnOpaqueDelegate namespace