aboutsummaryrefslogtreecommitdiff
path: root/src/backends/RegistryCommon.hpp
blob: 27663b6deaef1ea2973238377a39ad397fba2409 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

namespace armnn
{

/// Human-readable label for a registered type, used when building registry error messages.
///
/// This primary template is the fallback and always yields "UNKNOWN"; an explicit
/// specialization for a given type can provide a more specific label.
template <typename T>
struct RegisteredTypeName
{
    static const char * Name()
    {
        return "UNKNOWN";
    }
};

/// Shared storage and lookup for registries that map a BackendId to a factory function.
///
/// Each backend id maps to at most one factory producing a PointerType. RegisteredType is
/// used only to obtain a label (via RegisteredTypeName) for exception messages.
///
/// NOTE(review): no locking is performed here; concurrent Register/GetFactory calls would
/// need external synchronization — confirm how callers use this.
template <typename RegisteredType, typename PointerType>
class RegistryCommon
{
public:
    using FactoryFunction = std::function<PointerType()>;

    RegistryCommon() = default;
    virtual ~RegistryCommon() = default;

    // Non-copyable. Deleted publicly so misuse produces a clear "deleted function" diagnostic.
    RegistryCommon(const RegistryCommon&) = delete;
    RegistryCommon& operator=(const RegistryCommon&) = delete;

    /// Registers @p factory under backend @p id.
    /// @throws InvalidArgumentException if a factory is already registered for @p id;
    ///         in that case the existing registration is left unchanged.
    void Register(const BackendId& id, FactoryFunction factory)
    {
        // A single emplace both checks for and inserts the entry. Unlike `m_Factories[id] = factory`,
        // it never leaves a default-constructed (empty) factory in the map if constructing the
        // stored std::function throws.
        const bool inserted = m_Factories.emplace(id, std::move(factory)).second;
        if (!inserted)
        {
            throw InvalidArgumentException(
                std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
                CHECK_LOCATION());
        }
    }

    /// Returns a copy of the factory registered for @p id.
    /// @throws InvalidArgumentException if no factory is registered for @p id.
    FactoryFunction GetFactory(const BackendId& id) const
    {
        auto it = m_Factories.find(id);
        if (it == m_Factories.end())
        {
            throw InvalidArgumentException(
                std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
                CHECK_LOCATION());
        }

        return it->second;
    }

    /// Number of registered factories.
    size_t Size() const
    {
        return m_Factories.size();
    }

    /// Ids of all backends that currently have a registered factory.
    BackendIdSet GetBackendIds() const
    {
        BackendIdSet result;
        for (const auto& entry : m_Factories)
        {
            result.insert(entry.first);
        }
        return result;
    }

protected:
    using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;

    // For testing only: exchanges the registry's contents with @p other so a test can install
    // its own factories and later restore the originals by swapping back.
    static void Swap(RegistryCommon& instance, FactoryStorage& other)
    {
        std::swap(instance.m_Factories, other);
    }

private:
    FactoryStorage m_Factories;
};

/// Helper whose constructor registers a factory with a registry.
///
/// As the name suggests, it is presumably declared as a static object so that registration
/// happens during static initialization — confirm at call sites. Note that Register() may
/// throw (e.g. on a duplicate id); an exception escaping the constructor of a namespace-scope
/// static object results in std::terminate.
template <typename RegistryType>
struct StaticRegistryInitializer
{
    using FactoryFunction = typename RegistryType::FactoryFunction;

    /// Registers @p factory under @p id in @p instance.
    StaticRegistryInitializer(RegistryType& instance,
                              const BackendId& id,
                              FactoryFunction factory)
    {
        // factory is already our own by-value copy; hand it over without copying the
        // std::function (and any captured state) a second time.
        instance.Register(id, std::move(factory));
    }
};

} // namespace armnn