// Source path: src/armnn/BackendRegistry.cpp
// Blob: ff63c8236a1c062f6ac5f8487457f18a7402feb4
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <armnn/BackendRegistry.hpp>
#include <armnn/Exceptions.hpp>
#include <ProfilingService.hpp>

namespace armnn
{

/// Accessor for the shared BackendRegistry object.
///
/// The registry is a function-local static, so it is constructed the first
/// time this function is called and the same object is returned on every call.
///
/// @return Reference to the shared BackendRegistry.
BackendRegistry& BackendRegistryInstance()
{
    static BackendRegistry s_Registry;
    return s_Registry;
}

/// Stores a factory function under the given backend id.
///
/// If a profiling service has been set and reports profiling as enabled, the
/// REGISTERED_BACKENDS counter is incremented after the factory is stored.
///
/// @param id      Identifier the factory is stored under.
/// @param factory Factory function associated with @p id.
/// @throws InvalidArgumentException if a factory is already stored under @p id;
///         in that case neither the registry nor the counter is modified.
void BackendRegistry::Register(const BackendId& id, BackendRegistry::FactoryFunction factory)
{
    if (m_Factories.find(id) != m_Factories.end())
    {
        throw InvalidArgumentException(
            std::string(id) + " already registered as IBackend factory",
            CHECK_LOCATION());
    }
    m_Factories[id] = factory;

    // Single combined guard, matching Deregister(): has_value() must be checked
    // before value() is used, and the counter only moves when profiling is on.
    if (m_ProfilingService.has_value() && m_ProfilingService.value().IsProfilingEnabled())
    {
        m_ProfilingService.value().IncrementCounterValue(armnn::profiling::REGISTERED_BACKENDS);
    }
}

void BackendRegistry::Deregister(const BackendId& id)
{
    m_Factories.erase(id);

    if (m_ProfilingService.has_value() && m_ProfilingService.value().IsProfilingEnabled())
    {
        m_ProfilingService.value().IncrementCounterValue(armnn::profiling::UNREGISTERED_BACKENDS);
    }
}

/// Reports whether a factory is currently stored under the given backend id.
///
/// @param id Identifier to look up.
/// @return true if @p id has an entry in the registry, false otherwise.
bool BackendRegistry::IsBackendRegistered(const BackendId& id) const
{
    const auto entry = m_Factories.find(id);
    const bool isRegistered = (entry != m_Factories.end());
    return isRegistered;
}

/// Looks up the factory function stored under the given backend id.
///
/// @param id Identifier to look up.
/// @return A copy of the factory function stored under @p id.
/// @throws InvalidArgumentException if no factory is stored under @p id.
BackendRegistry::FactoryFunction BackendRegistry::GetFactory(const BackendId& id) const
{
    const auto entry = m_Factories.find(id);
    if (entry != m_Factories.end())
    {
        return entry->second;
    }

    // Unknown id: report it by name so the caller can see what was requested.
    throw InvalidArgumentException(
        std::string(id) + " has no IBackend factory registered",
        CHECK_LOCATION());
}

/// Number of factories currently stored in the registry.
///
/// @return The element count of the factory storage.
size_t BackendRegistry::Size() const
{
    const size_t factoryCount = m_Factories.size();
    return factoryCount;
}

/// Collects the ids of every backend that currently has a factory stored.
///
/// @return A set containing the key of each entry in the factory storage.
BackendIdSet BackendRegistry::GetBackendIds() const
{
    BackendIdSet ids;
    for (auto entry = m_Factories.begin(); entry != m_Factories.end(); ++entry)
    {
        // Only the key (the backend id) is of interest; the factory is ignored.
        ids.insert(entry->first);
    }
    return ids;
}

std::string BackendRegistry::GetBackendIdsAsString() const
{
    static const std::string delimitator = ", ";

    std::stringstream output;
    for (auto& backendId : GetBackendIds())
    {
        if (output.tellp() != std::streampos(0))
        {
            output << delimitator;
        }
        output << backendId;
    }

    return output.str();
}

/// Exchanges the factory storage of a registry with an external storage object.
///
/// After the call @p instance holds the former contents of @p other, and
/// @p other holds the factories that were previously stored in @p instance.
///
/// @param instance Registry whose factory storage is exchanged.
/// @param other    Storage to exchange with.
void BackendRegistry::Swap(BackendRegistry& instance, BackendRegistry::FactoryStorage& other)
{
    BackendRegistry::FactoryStorage& registryFactories = instance.m_Factories;
    std::swap(registryFactories, other);
}

/// Replaces the optional profiling service reference held by the registry.
///
/// Register() and Deregister() consult this value to decide whether to
/// increment their profiling counters.
///
/// @param profilingService Optional reference to store; may be empty.
void BackendRegistry::SetProfilingService(armnn::Optional<profiling::ProfilingService&> profilingService)
{
    this->m_ProfilingService = profilingService;
}


} // namespace armnn