aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
blob: ea6d19bd31ce70b50129a8a91d62a25c9f2500b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include "Runtime.hpp"

#include "armnn/Version.hpp"

#ifdef ARMCOMPUTECL_ENABLED
#include <arm_compute/core/CL/OpenCL.h>
#include <arm_compute/core/CL/CLKernelLibrary.h>
#endif

#include <boost/log/trivial.hpp>
#include <boost/polymorphic_cast.hpp>

using namespace armnn;
using namespace std;

namespace armnn
{

IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
{
    // Allocate the concrete Runtime; ownership transfers to the caller,
    // who is expected to release it via IRuntime::Destroy.
    Runtime* runtime = new Runtime(options);
    return runtime;
}

IRuntimePtr IRuntime::Create(const CreationOptions& options)
{
    // Build the raw instance, then hand it to a smart pointer whose custom
    // deleter routes destruction back through IRuntime::Destroy.
    IRuntime* raw = IRuntime::CreateRaw(options);
    return IRuntimePtr(raw, &IRuntime::Destroy);
}

void IRuntime::Destroy(IRuntime* runtime)
{
    // Downcast to the concrete type before deleting; polymorphic_downcast
    // asserts the cast is valid in debug builds.
    Runtime* concreteRuntime = boost::polymorphic_downcast<Runtime*>(runtime);
    delete concreteRuntime;
}

int Runtime::GenerateNetworkId()
{
    // Hand out the current counter value, then advance it for the next network.
    const int newId = m_NetworkIdCounter;
    m_NetworkIdCounter = newId + 1;
    return newId;
}

Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
{
    // Take ownership out of the caller-supplied smart pointer and re-wrap the
    // concrete OptimizedNetwork type for LoadedNetwork to consume.
    std::unique_ptr<OptimizedNetwork> optimizedNet(
        boost::polymorphic_downcast<OptimizedNetwork*>(inNetwork.release()));

    auto loadedNetwork = LoadedNetwork::MakeLoadedNetwork(std::move(optimizedNet), m_WorkloadFactories);
    if (loadedNetwork == nullptr)
    {
        return Status::Failure;
    }

    // Assign a fresh id and register the network so later calls can find it.
    networkIdOut = GenerateNetworkId();
    m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);

    return Status::Success;
}

Status Runtime::UnloadNetwork(NetworkId networkId)
{
    // erase() reports how many entries were removed; zero means the id was unknown.
    const bool removed = m_LoadedNetworks.erase(networkId) != 0;
    if (!removed)
    {
        BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
        return Status::Failure;
    }

#ifdef ARMCOMPUTECL_ENABLED
    // Clear the cached OpenCL programs now that the network is gone.
    arm_compute::CLKernelLibrary::get().clear_programs_cache();
#endif
    BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
    return Status::Success;
}

Runtime::Runtime(const CreationOptions& options)
: m_NetworkIdCounter(0)
{
    BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
    BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
    m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;

    // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
    // operation workloads, unless the default compute device is precisely the reference backend.
    // (cond ? true : flag) simplified to the equivalent logical-or.
    const bool useCpuRef =
        options.m_DefaultComputeDevice == Compute::CpuRef || options.m_UseCpuRefAsFallback;
    m_WorkloadFactories.m_CpuRef = make_shared<RefWorkloadFactory>(useCpuRef);
    m_WorkloadFactories.m_CpuAcc = make_shared<NeonWorkloadFactory>();
    m_WorkloadFactories.m_GpuAcc = make_shared<ClWorkloadFactory>();

    if (options.m_DefaultComputeDevice == Compute::GpuAcc)
    {
        // shared_ptr already provides operator->; the former ".get()->" indirection was redundant.
        m_WorkloadFactories.m_GpuAcc->LoadOpenClRuntime(options.m_ClTunedParameters);
    }
}

TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
    // at() throws std::out_of_range if networkId was never loaded (or was unloaded);
    // otherwise forward the query to the loaded network.
    return m_LoadedNetworks.at(networkId)->GetInputTensorInfo(layerId);
}

TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
    // at() throws std::out_of_range for an unknown id; the network is read-only here.
    const LoadedNetwork& net = *m_LoadedNetworks.at(networkId);
    return net.GetOutputTensorInfo(layerId);
}

Status Runtime::EnqueueWorkload(NetworkId networkId,
                                     const InputTensors& inputTensors,
                                     const OutputTensors& outputTensors)
{
    // Resolve the network (at() throws for an unknown id) and delegate execution,
    // passing along the per-backend workload factories.
    auto& net = m_LoadedNetworks.at(networkId);
    return net->EnqueueWorkload(inputTensors, outputTensors, m_WorkloadFactories);
}

}