aboutsummaryrefslogtreecommitdiff
path: root/delegate/include/armnn_delegate.hpp
blob: adf264aabbcc24749fa06a3f874a45a8bfb6f7f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "DelegateOptions.hpp"

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

/// Bundle of state handed by reference to ArmnnSubgraph::VisitNode, AddInputLayer and AddOutputLayer.
struct DelegateData
{
    /// Stores a copy of @p backends; the remaining members use their default initial state.
    DelegateData(const std::vector<armnn::BackendId>& backends) : m_Backends(backends) {}

    /// Backend list copied in at construction; const, so it cannot be changed afterwards.
    const std::vector<armnn::BackendId>       m_Backends;
    /// Network handle; starts out constructed from (nullptr, nullptr).
    armnn::INetworkPtr                        m_Network{nullptr, nullptr};
    /// Raw output-slot pointers, initially empty.
    /// NOTE(review): presumably one entry per node/tensor - the indexing scheme is not visible here, confirm.
    std::vector<armnn::IOutputSlot*>          m_OutputSlotForNode;
};

// Forward declarations for functions initializing the ArmNN Delegate.
// Only the declarations are visible in this header; the definitions live elsewhere.

/// Returns a DelegateOptions value.
/// NOTE(review): presumably the default configuration, going by the name - confirm against the definition.
DelegateOptions TfLiteArmnnDelegateOptionsDefault();

/// Takes a DelegateOptions by value and returns a TfLiteDelegate pointer.
/// NOTE(review): presumably the returned pointer must be released with TfLiteArmnnDelegateDelete - confirm.
TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);

/// Takes a TfLiteDelegate pointer and returns nothing.
/// NOTE(review): presumably the counterpart of TfLiteArmnnDelegateCreate; null handling is not visible here.
void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);

/// Used as the Prepare entry of the TfLiteDelegate held by the Delegate class in this header.
TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);

/// ArmNN Delegate
///
/// Holds a TfLiteDelegate whose data_ field is initialised with the address of this
/// object and whose Prepare callback is DoPrepare, together with the ArmNN runtime
/// pointer and the options the delegate was constructed with.
class Delegate
{
    friend class ArmnnSubgraph;
public:
    /// Constructs the delegate from a DelegateOptions value (taken by value).
    explicit Delegate(armnnDelegate::DelegateOptions options);

    // m_Delegate.data_ holds the address of the object it was initialised in. A copied or
    // moved Delegate would therefore carry a data_ pointer to the source object, which
    // dangles once the source is destroyed. Forbid copying and moving altogether.
    Delegate(const Delegate&) = delete;
    Delegate& operator=(const Delegate&) = delete;
    Delegate(Delegate&&) = delete;
    Delegate& operator=(Delegate&&) = delete;

    /// Returns a TfLiteIntArray pointer computed from the given context.
    /// NOTE(review): presumably the indices of the nodes this delegate will take over, and
    /// presumably owned by the caller - confirm against the definition.
    TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);

    /// Returns a TfLiteDelegate pointer.
    /// NOTE(review): presumably the address of m_Delegate - confirm against the definition.
    TfLiteDelegate* GetDelegate();

private:
    /// TfLite-facing delegate record; entries are positional, see the trailing field names.
    TfLiteDelegate m_Delegate = {
        static_cast<void*>(this),       // .data_ (back-pointer to this object)
        DoPrepare,                      // .Prepare
        nullptr,                        // .CopyFromBufferHandle
        nullptr,                        // .CopyToBufferHandle
        nullptr,                        // .FreeBufferHandle
        kTfLiteDelegateFlagsNone,       // .flags
    };

    /// ArmNN Runtime pointer
    armnn::IRuntimePtr m_Runtime;
    /// ArmNN Delegate Options
    armnnDelegate::DelegateOptions m_Options;
};

/// ArmnnSubgraph: parses the nodes into ArmNN format and creates the ArmNN graph.
///
/// The constructor is private; the public way to obtain an instance is the static Create().
class ArmnnSubgraph
{
public:
    /// Factory taking the TfLite context, the delegate parameters and the owning Delegate.
    /// Returns a raw ArmnnSubgraph pointer.
    /// NOTE(review): ownership of the returned pointer and the failure value are not visible here - confirm.
    static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
                                 const TfLiteDelegateParams* parameters,
                                 const Delegate* delegate);

    /// Prepare step for this subgraph; reports its outcome as a TfLiteStatus.
    TfLiteStatus Prepare(TfLiteContext* tfLiteContext);

    /// Invoke step for this subgraph on the given node; reports its outcome as a TfLiteStatus.
    TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);

    /// Visits a single TfLite node, identified by @p nodeIndex, using the shared DelegateData.
    /// Reports its outcome as a TfLiteStatus.
    static TfLiteStatus VisitNode(DelegateData& delegateData,
                                  TfLiteContext* tfLiteContext,
                                  TfLiteRegistration* tfLiteRegistration,
                                  TfLiteNode* tfLiteNode,
                                  int nodeIndex);

private:
    /// Stores the network id and the runtime pointer, and copies both binding vectors into the members.
    ArmnnSubgraph(armnn::NetworkId networkId,
                  armnn::IRuntime* runtime,
                  std::vector<armnn::BindingPointInfo>& inputBindings,
                  std::vector<armnn::BindingPointInfo>& outputBindings)
        : m_NetworkId(networkId)
        , m_Runtime(runtime)
        , m_InputBindings(inputBindings)
        , m_OutputBindings(outputBindings)
    {}

    /// Input-layer helper: receives the input tensor index array and a vector of input bindings
    /// by non-const reference.
    /// NOTE(review): presumably appends one binding per added input layer - confirm against the definition.
    static TfLiteStatus AddInputLayer(DelegateData& delegateData,
                                      TfLiteContext* tfLiteContext,
                                      const TfLiteIntArray* inputs,
                                      std::vector<armnn::BindingPointInfo>& inputBindings);

    /// Output-layer helper: receives the output tensor index array and a vector of output bindings
    /// by non-const reference.
    /// NOTE(review): presumably appends one binding per added output layer - confirm against the definition.
    static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
                                       TfLiteContext* tfLiteContext,
                                       const TfLiteIntArray* outputs,
                                       std::vector<armnn::BindingPointInfo>& outputBindings);

    /// The Network Id
    armnn::NetworkId m_NetworkId;
    /// ArmNN Runtime, held as a raw pointer
    armnn::IRuntime* m_Runtime;

    /// Binding information for the inputs (copied in by the constructor)
    std::vector<armnn::BindingPointInfo> m_InputBindings;
    /// Binding information for the outputs (copied in by the constructor)
    std::vector<armnn::BindingPointInfo> m_OutputBindings;
};

} // armnnDelegate namespace