aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
blob: e0d6a9add0fd4db85e37ece734bb426e57b2881b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include "Runtime.hpp"

#include "armnn/Version.hpp"

#ifdef ARMCOMPUTECL_ENABLED
#include <arm_compute/core/CL/OpenCL.h>
#include <arm_compute/core/CL/CLKernelLibrary.h>
#include <arm_compute/runtime/CL/CLScheduler.h>
#endif

#include <boost/log/trivial.hpp>
#include <boost/polymorphic_cast.hpp>

using namespace armnn;
using namespace std;

namespace armnn
{

// Factory for a concrete Runtime. The caller owns the returned pointer and
// must release it via IRuntime::Destroy() (or use IRuntime::Create()).
IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
{
    return new Runtime(options);
}

// Smart-pointer factory: wraps the raw runtime in an IRuntimePtr whose custom
// deleter is IRuntime::Destroy, so teardown goes through the concrete type.
IRuntimePtr IRuntime::Create(const CreationOptions& options)
{
    IRuntime* const raw = CreateRaw(options);
    return IRuntimePtr(raw, &IRuntime::Destroy);
}

// Deleter used by IRuntimePtr. Downcasts to the concrete Runtime before
// deleting; polymorphic_downcast asserts the cast is valid in debug builds.
void IRuntime::Destroy(IRuntime* runtime)
{
    Runtime* const concrete = boost::polymorphic_downcast<Runtime*>(runtime);
    delete concrete;
}

// Produces the next network identifier from a plain monotonically increasing
// counter. No synchronisation is performed here.
int Runtime::GenerateNetworkId()
{
    const int id = m_NetworkIdCounter;
    ++m_NetworkIdCounter;
    return id;
}

// Takes ownership of an optimized network, builds workloads for it, and
// registers it under a freshly generated id.
// @param networkIdOut  Receives the new network's id on success; untouched on failure.
// @param inNetwork     Optimized network to load; ownership is always consumed.
// @return Status::Success when the network was loaded, Status::Failure otherwise.
Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
{
    // Re-wrap the released pointer as the concrete OptimizedNetwork type that
    // LoadedNetwork::MakeLoadedNetwork expects.
    std::unique_ptr<OptimizedNetwork> optimizedNetwork(
        boost::polymorphic_downcast<OptimizedNetwork*>(inNetwork.release()));

    auto network = LoadedNetwork::MakeLoadedNetwork(std::move(optimizedNetwork), m_WorkloadFactories);
    if (network == nullptr)
    {
        return Status::Failure;
    }

    // Register the loaded network under a new id.
    networkIdOut = GenerateNetworkId();
    m_LoadedNetworks[networkIdOut] = std::move(network);

    return Status::Success;
}

// Removes a previously loaded network and releases its resources.
// @param networkId  Id returned by LoadNetwork().
// @return Status::Success if the network was found and removed, Status::Failure otherwise.
Status Runtime::UnloadNetwork(NetworkId networkId)
{
#ifdef ARMCOMPUTECL_ENABLED
    // Wait for any outstanding OpenCL work before destroying the network's resources.
    // context()() yields the underlying cl_context handle; a null handle means no CL
    // context has been created, so there is nothing to synchronise.
    if (arm_compute::CLScheduler::get().context()() != nullptr)
    {
        arm_compute::CLScheduler::get().sync();
    }
#endif
    if (m_LoadedNetworks.erase(networkId) == 0)
    {
        BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
        return Status::Failure;
    }
#ifdef ARMCOMPUTECL_ENABLED
    // Once the last network is gone, reinitialise the OpenCL runtime — presumably to
    // drop cached CL state held by the GPU workload factory (TODO: confirm intent).
    if (arm_compute::CLScheduler::get().context()() != nullptr && m_LoadedNetworks.empty())
    {
        m_WorkloadFactories.m_GpuAcc->LoadOpenClRuntime();
    }
#endif
    BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
    return Status::Success;
}

// Constructs the runtime: records the default compute device and creates one
// workload factory per backend (CpuRef, CpuAcc, GpuAcc).
Runtime::Runtime(const CreationOptions& options)
: m_NetworkIdCounter(0)
{
    BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
    BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
    m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;

    // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
    // operation workloads, unless the default compute device is precisely the reference backend.
    m_WorkloadFactories.m_CpuRef = make_shared<RefWorkloadFactory>(
        options.m_DefaultComputeDevice == Compute::CpuRef || options.m_UseCpuRefAsFallback);
    m_WorkloadFactories.m_CpuAcc = make_shared<NeonWorkloadFactory>();
    m_WorkloadFactories.m_GpuAcc = make_shared<ClWorkloadFactory>(options.m_ClTunedParameters);

    // Initialise the OpenCL runtime up front when the GPU backend is the default device.
    if (options.m_DefaultComputeDevice == Compute::GpuAcc)
    {
        m_WorkloadFactories.m_GpuAcc->LoadOpenClRuntime();
    }
}

// Unloads every remaining network on destruction. UnloadNetwork() erases
// entries from m_LoadedNetworks, so we repeatedly take the first remaining id
// instead of iterating the map while it is being mutated.
Runtime::~Runtime()
{
    while (!m_LoadedNetworks.empty())
    {
        UnloadNetwork(m_LoadedNetworks.begin()->first);
    }
}

// Returns the TensorInfo of the given input binding point of a loaded network.
// at() throws std::out_of_range for an unknown networkId.
TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
    // const pointer for consistency with GetOutputTensorInfo below.
    const LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
    return net->GetInputTensorInfo(layerId);
}

// Returns the TensorInfo of the given output binding point of a loaded network.
// at() throws std::out_of_range for an unknown networkId.
TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
    return m_LoadedNetworks.at(networkId)->GetOutputTensorInfo(layerId);
}

// Runs inference on a loaded network with the supplied input/output tensors.
// at() throws std::out_of_range for an unknown networkId.
Status Runtime::EnqueueWorkload(NetworkId networkId,
                                const InputTensors& inputTensors,
                                const OutputTensors& outputTensors)
{
    return m_LoadedNetworks.at(networkId)->EnqueueWorkload(inputTensors, outputTensors, m_WorkloadFactories);
}

}