aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/BackendRegistry.cpp
blob: 80daed9896a06a0f51609943b6b2a03c9d7606e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <armnn/BackendRegistry.hpp>
#include <armnn/Exceptions.hpp>
#include <ProfilingService.hpp>

namespace armnn
{

/// Gives access to the shared BackendRegistry object.
///
/// The registry is held in a function-local static, so it is constructed on
/// the first call and every call returns a reference to that same object.
BackendRegistry& BackendRegistryInstance()
{
    static BackendRegistry s_Registry;
    return s_Registry;
}

/// Registers the factory function associated with a backend id.
///
/// @param id      Identifier of the backend being registered.
/// @param factory Factory function stored against @p id.
/// @throws InvalidArgumentException if a factory is already registered for @p id.
void BackendRegistry::Register(const BackendId& id, BackendRegistry::FactoryFunction factory)
{
    if (m_Factories.find(id) != m_Factories.end())
    {
        throw InvalidArgumentException(
            std::string(id) + " already registered as IBackend factory",
            CHECK_LOCATION());
    }
    m_Factories[id] = factory;

    // Bump the counter only when a profiling service has been set and reports
    // that profiling is enabled. A single combined check is sufficient.
    if (m_ProfilingService.has_value() && m_ProfilingService.value().IsProfilingEnabled())
    {
        m_ProfilingService.value().IncrementCounterValue(armnn::profiling::REGISTERED_BACKENDS);
    }
}

void BackendRegistry::Deregister(const BackendId& id)
{
    m_Factories.erase(id);
    DeregisterAllocator(id);

    if (m_ProfilingService.has_value() && m_ProfilingService.value().IsProfilingEnabled())
    {
        m_ProfilingService.value().IncrementCounterValue(armnn::profiling::UNREGISTERED_BACKENDS);
    }
}

/// Reports whether a factory is currently registered for @p id.
///
/// @param id Identifier of the backend to look up.
/// @return true if the registry holds a factory for @p id, false otherwise.
bool BackendRegistry::IsBackendRegistered(const BackendId& id) const
{
    return m_Factories.count(id) != 0;
}

/// Looks up the factory function registered for @p id.
///
/// @param id Identifier of the backend whose factory is requested.
/// @return A copy of the registered factory function.
/// @throws InvalidArgumentException if no factory is registered for @p id.
BackendRegistry::FactoryFunction BackendRegistry::GetFactory(const BackendId& id) const
{
    const auto entry = m_Factories.find(id);
    if (entry != m_Factories.end())
    {
        return entry->second;
    }

    throw InvalidArgumentException(
        std::string(id) + " has no IBackend factory registered",
        CHECK_LOCATION());
}

/// Returns the number of backend factories currently held by the registry.
size_t BackendRegistry::Size() const
{
    const size_t factoryCount = m_Factories.size();
    return factoryCount;
}

/// Collects the ids of all backends that currently have a factory registered.
///
/// @return A newly built set containing one id per registered factory.
BackendIdSet BackendRegistry::GetBackendIds() const
{
    BackendIdSet registeredIds;
    for (const auto& factoryEntry : m_Factories)
    {
        // Only the key (the backend id) is of interest here.
        registeredIds.insert(factoryEntry.first);
    }
    return registeredIds;
}

std::string BackendRegistry::GetBackendIdsAsString() const
{
    static const std::string delimitator = ", ";

    std::stringstream output;
    for (auto& backendId : GetBackendIds())
    {
        if (output.tellp() != std::streampos(0))
        {
            output << delimitator;
        }
        output << backendId;
    }

    return output.str();
}

/// Exchanges the factory storage of @p instance with @p other.
///
/// @param instance Registry whose factory storage is swapped.
/// @param other    Storage that receives the registry's previous factories and
///                 whose previous contents end up in the registry.
void BackendRegistry::Swap(BackendRegistry& instance, BackendRegistry::FactoryStorage& other)
{
    instance.m_Factories.swap(other);
}

/// Stores the optional profiling service consulted by Register() and
/// Deregister() when they update the backend counters.
///
/// @param profilingService Optional reference to a profiling service; passing
///                         an empty optional leaves the registry without one.
// NOTE(review): the parameter is an optional reference, so the registry
// presumably does not own the service and the caller must keep it alive -
// confirm against armnn::Optional and the callers.
void BackendRegistry::SetProfilingService(armnn::Optional<profiling::ProfilingService&> profilingService)
{
    m_ProfilingService = profilingService;
}

/// Associates a custom allocator with a backend id.
///
/// @param id    Identifier of the backend the allocator belongs to.
/// @param alloc Shared pointer to the allocator stored against @p id.
/// @throws InvalidArgumentException if @p id already has an allocator.
void BackendRegistry::RegisterAllocator(const BackendId& id, std::shared_ptr<ICustomAllocator> alloc)
{
    const bool alreadyHasAllocator = (m_CustomMemoryAllocatorMap.count(id) != 0);
    if (alreadyHasAllocator)
    {
        throw InvalidArgumentException(
            std::string(id) + " already has an allocator associated with it",
            CHECK_LOCATION());
    }

    m_CustomMemoryAllocatorMap[id] = alloc;
}

/// Removes the custom allocator entry stored for @p id, if there is one.
///
/// @param id Identifier of the backend whose allocator entry is erased.
void BackendRegistry::DeregisterAllocator(const BackendId& id)
{
    // Erasing by key is a no-op when no entry exists for the id.
    m_CustomMemoryAllocatorMap.erase(id);
}

/// Returns the custom allocators registered so far, keyed by backend id.
///
/// @return The map by value, i.e. a copy of the registry's allocator map;
///         changes to the returned map do not affect the registry's map.
std::unordered_map<BackendId, std::shared_ptr<ICustomAllocator>> BackendRegistry::GetAllocators()
{
    return m_CustomMemoryAllocatorMap;
}

} // namespace armnn