aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/armnn_delegate.cpp
blob: f8a8aca13900b787f9b1f1b2a3ffc50c9399fc31 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <armnn_delegate.hpp>
#include <algorithm>
#include <string>

namespace armnnDelegate
{

/// Constructs the delegate: creates an Arm NN runtime with default creation options and
/// narrows the requested backends down to those the runtime reports as supported.
///
/// @param options Delegate options; its backend list is replaced by the filtered list.
/// @throws armnn::InvalidArgumentException if none of the requested backends is supported.
///         This is also the outcome when the runtime could not be created, because no
///         backend can be validated in that case.
Delegate::Delegate(armnnDelegate::DelegateOptions options)
  : m_Runtime(nullptr, nullptr),
    m_Options(std::move(options))
{
    // Create ArmNN Runtime
    armnn::IRuntime::CreationOptions runtimeOptions;
    m_Runtime = armnn::IRuntime::Create(runtimeOptions);

    std::vector<armnn::BackendId> backends;

    if (m_Runtime)
    {
        // Bind by const reference so the backend set is not copied needlessly.
        const armnn::BackendIdSet& supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends();
        for (const auto& backend : m_Options.GetBackends())
        {
            // Use the set's own lookup rather than a linear std::find over the container.
            if (supportedDevices.find(backend) == supportedDevices.cend())
            {
                TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                    "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str());
            }
            else
            {
                backends.push_back(backend);
            }
        }
    }

    if (backends.empty())
    {
        // No known backend specified
        throw armnn::InvalidArgumentException("TfLiteArmnnDelegate: No known backend specified.");
    }
    m_Options.SetBackends(backends);

    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnDelegate: Created TfLite ArmNN delegate.");
}

/// Walks the TfLite execution plan and gathers the indices of the nodes for which
/// ArmnnSubgraph::VisitNode reports kTfLiteOk.
///
/// @param tfLiteContext TfLite context used to query the execution plan and each node.
/// @return Array created with TfLiteIntArrayCreate holding the accepted node indices in
///         ascending order (its size field is the number of accepted nodes), or nullptr
///         if the execution plan could not be obtained or the array could not be allocated.
///         NOTE(review): the caller presumably releases the array with TfLiteIntArrayFree - confirm.
TfLiteIntArray* Delegate::CollectOperatorsToDelegate(TfLiteContext* tfLiteContext)
{
    TfLiteIntArray* executionPlan = nullptr;
    if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
    {
        TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan.");
        return nullptr;
    }

    // Null INetworkPtr: VisitNode is called without a network to add layers to.
    armnn::INetworkPtr nullNetworkPtr(nullptr, nullptr);

    // Allocate room for the worst case (every node accepted).
    TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size);
    if (nodesToDelegate == nullptr)
    {
        // TfLiteIntArrayCreate can fail; do not dereference a null array.
        TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to allocate the list of nodes to delegate.");
        return nullptr;
    }

    // The size field is reused as the running count of accepted nodes.
    nodesToDelegate->size = 0;
    for (int i = 0; i < executionPlan->size; ++i)
    {
        const int nodeIndex = executionPlan->data[i];

        // If TfLite nodes can be delegated to ArmNN
        TfLiteNode* tfLiteNode = nullptr;
        TfLiteRegistration* tfLiteRegistration = nullptr;
        if (tfLiteContext->GetNodeAndRegistration(
            tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
        {
            TF_LITE_KERNEL_LOG(tfLiteContext,
                               "TfLiteArmnnDelegate: Unable to get node and registration for node %d.",
                               nodeIndex);
            continue;
        }

        if (ArmnnSubgraph::VisitNode(
            nullNetworkPtr, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
        {
            // node is not supported by ArmNN
            continue;
        }

        nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex;
    }

    // Sort only the populated prefix of the array.
    std::sort(&nodesToDelegate->data[0],
              &nodesToDelegate->data[nodesToDelegate->size]);

    return nodesToDelegate;
}

/// Gives access to the TfLiteDelegate held by this object.
///
/// @return Address of the m_Delegate member.
TfLiteDelegate* Delegate::GetDelegate()
{
    TfLiteDelegate* const delegateHandle = &m_Delegate;
    return delegateHandle;
}

/// Builds an Arm NN network from the TfLite nodes listed in parameters->nodes_to_replace,
/// optimises it for the delegate's backends, loads it into the delegate's runtime and
/// wraps the resulting network id in a new ArmnnSubgraph.
///
/// @param tfLiteContext TfLite context used to query the execution plan and each node.
/// @param parameters    Delegate parameters; nodes_to_replace holds the node indices to convert.
/// @param delegate      Delegate providing the backend options and the Arm NN runtime.
/// @return Subgraph allocated with new, or nullptr if the execution plan cannot be obtained.
///         NOTE(review): the caller presumably owns and deletes the returned object - confirm.
/// @throws armnn::Exception if a node or its registration cannot be retrieved, a node cannot
///         be parsed, the network cannot be optimised, or it cannot be loaded into the runtime.
ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext,
                                     const TfLiteDelegateParams* parameters,
                                     const Delegate* delegate)
{
    // The plan itself is not used below; the call only checks that it can be obtained.
    TfLiteIntArray* executionPlan = nullptr;
    if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
    {
        return nullptr;
    }

    // Construct ArmNN network
    armnn::NetworkOptions networkOptions = {};
    armnn::INetworkPtr network = armnn::INetwork::Create(networkOptions);

    // Parse TfLite delegate nodes to ArmNN nodes
    for (int i = 0; i < parameters->nodes_to_replace->size; ++i)
    {
        const int nodeIndex = parameters->nodes_to_replace->data[i];

        TfLiteNode* tfLiteNode = nullptr;
        TfLiteRegistration* tfLiteRegistration = nullptr;
        if (tfLiteContext->GetNodeAndRegistration(
            tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
        {
            // std::to_string is required: adding an int to a string literal is pointer
            // arithmetic, not concatenation.
            throw armnn::Exception("TfLiteArmnnDelegate: Unable to get node registration: " +
                                   std::to_string(nodeIndex));
        }

        if (VisitNode(network, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
        {
            throw armnn::Exception("TfLiteArmnnDelegate: Unable to parse node: " +
                                   std::to_string(nodeIndex));
        }
    }

    // Optimise Arm NN network
    armnn::IOptimizedNetworkPtr optNet =
        armnn::Optimize(*network, delegate->m_Options.GetBackends(), delegate->m_Runtime->GetDeviceSpec());
    if (!optNet)
    {
        // Optimize Failed
        throw armnn::Exception("TfLiteArmnnDelegate: Unable to optimize the network!");
    }

    // Load graph into runtime; a failed load must not produce a subgraph with an unset network id.
    armnn::NetworkId networkId;
    if (delegate->m_Runtime->LoadNetwork(networkId, std::move(optNet)) != armnn::Status::Success)
    {
        throw armnn::Exception("TfLiteArmnnDelegate: Unable to load the network into the runtime!");
    }

    // Create a new SubGraph with networkId and runtime
    return new ArmnnSubgraph(networkId, delegate->m_Runtime.get());
}

/// Preparation step of the subgraph. No work is done here at present.
///
/// @param tfLiteContext TfLite context; not consulted by the current implementation.
/// @return Always kTfLiteOk.
TfLiteStatus ArmnnSubgraph::Prepare(TfLiteContext* tfLiteContext)
{
    const TfLiteStatus prepareStatus = kTfLiteOk;
    return prepareStatus;
}

/// Execution step of the subgraph. Running the network is not implemented yet, so this
/// reports success without doing any work.
///
/// @param tfLiteContext TfLite context; not consulted by the current implementation.
/// @return Always kTfLiteOk.
TfLiteStatus ArmnnSubgraph::Invoke(TfLiteContext* tfLiteContext)
{
    // Still to do: fetch the input and output tensors from the context and execute
    // the network, along the lines of:
    //   m_Runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
    const TfLiteStatus invokeStatus = kTfLiteOk;
    return invokeStatus;
}

/// Placeholder for per-operator handling. No operator is handled yet, so every node is
/// reported as an error.
///
/// @param network            Arm NN network passed in by the caller; not touched here.
/// @param tfLiteContext      TfLite context; not consulted here.
/// @param tfLiteRegistration Registration of the node; not consulted here.
/// @param tfLiteNode         The TfLite node; not consulted here.
/// @param nodeIndex          Index of the node; not consulted here.
/// @return Always kTfLiteError.
TfLiteStatus ArmnnSubgraph::VisitNode(armnn::INetworkPtr& network,
                                      TfLiteContext* tfLiteContext,
                                      TfLiteRegistration* tfLiteRegistration,
                                      TfLiteNode* tfLiteNode,
                                      int nodeIndex)
{
    // Intended design: inspect the operator type of the node and dispatch to the matching
    // VisitXXXLayer() function, which parses the TfLite node into an Arm NN layer and adds
    // it to the network graph. Sketch:
    //
    //   switch (tfLiteRegistration->builtin_code)
    //   {
    //       case kTfLiteBuiltinAbs:
    //           return VisitAbsLayer(...);
    //       ...
    //       default:
    //           return kTfLiteError;
    //   }
    const TfLiteStatus visitStatus = kTfLiteError;
    return visitStatus;
}

} // armnnDelegate namespace