aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/RegistryCommon.hpp
blob: 3dbfad2a668b6e04d61ac328b2f8ca854f60b215 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>

#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>

namespace armnn
{

/// Supplies a human-readable name for @p RegisteredType, used when building
/// registry error messages. Specialise this template to give a registered type
/// a meaningful name; the primary template below is the fallback.
template <typename RegisteredType>
struct RegisteredTypeName
{
    /// @return "UNKNOWN" for any type without a specialisation.
    static const char * Name()
    {
        const char* const fallbackName = "UNKNOWN";
        return fallbackName;
    }
};

/// Generic registry that maps a BackendId to a factory function.
///
/// @tparam RegisteredType  Tag type; only used to name the registry in error messages
///                         via RegisteredTypeName<RegisteredType>::Name().
/// @tparam PointerType     Type returned by each registered factory.
/// @tparam ParamType       Type of the argument passed (by const reference) to each factory.
///
/// Instances are non-copyable. This class performs no locking of its own.
template <typename RegisteredType, typename PointerType, typename ParamType>
class RegistryCommon
{
public:
    using FactoryFunction = std::function<PointerType(const ParamType&)>;

    /// Registers @p factory under @p id.
    /// @throws InvalidArgumentException if a factory is already registered for @p id;
    ///         in that case the existing registration is left unchanged.
    void Register(const BackendId& id, FactoryFunction factory)
    {
        // emplace() performs a single lookup and never overwrites an existing entry,
        // so the duplicate check and the insertion cannot disagree.
        const bool inserted = m_Factories.emplace(id, std::move(factory)).second;
        if (!inserted)
        {
            throw InvalidArgumentException(
                std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
                CHECK_LOCATION());
        }
    }

    /// Returns (a copy of) the factory registered for @p id.
    /// @throws InvalidArgumentException if no factory is registered for @p id.
    FactoryFunction GetFactory(const BackendId& id) const
    {
        auto it = m_Factories.find(id);
        if (it == m_Factories.end())
        {
            throw InvalidArgumentException(
                std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
                CHECK_LOCATION());
        }

        return it->second;
    }

    /// Returns (a copy of) the factory registered for @p id, or @p defaultFactory if none is registered.
    FactoryFunction GetFactory(const BackendId& id,
                               FactoryFunction defaultFactory) const
    {
        auto it = m_Factories.find(id);
        if (it != m_Factories.end())
        {
            return it->second;
        }

        // Returning a by-value parameter by name moves it rather than copying it.
        return defaultFactory;
    }

    /// @return The number of registered factories.
    size_t Size() const
    {
        return m_Factories.size();
    }

    /// @return The set of ids that currently have a registered factory.
    BackendIdSet GetBackendIds() const
    {
        BackendIdSet result;
        for (const auto& entry : m_Factories)
        {
            result.insert(entry.first);
        }
        return result;
    }

    /// @return The registered ids, in the iteration order of GetBackendIds(), joined with ", ";
    ///         an empty string when nothing is registered.
    std::string GetBackendIdsAsString() const
    {
        const char* const delimiter = ", ";

        std::ostringstream output;
        bool first = true;
        for (const auto& backendId : GetBackendIds())
        {
            // Track the first element explicitly instead of probing the stream position:
            // an id that prints as an empty string would otherwise suppress the next delimiter.
            if (!first)
            {
                output << delimiter;
            }
            first = false;
            output << backendId;
        }

        return output.str();
    }

    RegistryCommon() = default;
    virtual ~RegistryCommon() = default;

protected:
    using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;

    // For testing only: exchanges the registry's contents with @p other.
    static void Swap(RegistryCommon& instance, FactoryStorage& other)
    {
        std::swap(instance.m_Factories, other);
    }

private:
    RegistryCommon(const RegistryCommon&) = delete;
    RegistryCommon& operator=(const RegistryCommon&) = delete;

    FactoryStorage m_Factories;
};

/// Helper whose constructor registers a factory with a registry.
/// Constructing an instance calls instance.Register(id, factory); for example, a
/// namespace-scope object of this type performs the registration during static
/// initialisation. Any exception thrown by RegistryType::Register propagates.
template <typename RegistryType>
struct StaticRegistryInitializer
{
    using FactoryFunction = typename RegistryType::FactoryFunction;

    StaticRegistryInitializer(RegistryType& instance,
                              const BackendId& id,
                              FactoryFunction factory)
    {
        // factory is a by-value sink parameter; move it on instead of copying the std::function again.
        instance.Register(id, std::move(factory));
    }
};

} // namespace armnn