aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/BackendRegistry.hpp
blob: c13aa9f8b66d38618acf6ab2773ce56bc03a1302 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/Types.hpp>
#include <armnn/BackendId.hpp>
#include <armnn/Optional.hpp>
#include <armnn/backends/ICustomAllocator.hpp>

#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

namespace armnn
{

// Forward declaration only: this header refers to profiling::ProfilingService
// solely by reference (inside armnn::Optional), so the full definition is not
// pulled in here.
namespace profiling
{
    class ProfilingService;
}
// Forward declaration; the complete type is not visible from this header.
class IBackendInternal;
// Owning-pointer alias for IBackendInternal; BackendRegistry uses it as the
// result type of its backend factory functions.
using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>;

/// Registry that stores, per BackendId, a factory function producing an
/// IBackendInternal object and, separately, a custom memory allocator.
///
/// Apart from the trivial inline members below, every member function is only
/// declared here and defined out of line, so details such as error handling are
/// not visible from this header. The class is non-copyable.
class BackendRegistry
{
public:
    /// Owning pointer type returned by a backend factory.
    using PointerType = IBackendInternalUniquePtr;
    /// Callable taking no arguments and returning a PointerType.
    using FactoryFunction = std::function<PointerType()>;

    /// Registers @p factory under backend @p id.
    /// NOTE(review): behaviour for an id that is already registered is decided
    /// by the out-of-line definition - confirm there.
    void Register(const BackendId& id, FactoryFunction factory);

    /// @return whether @p id is known to the registry (presumably: has a
    ///         registered factory - confirm against the definition).
    bool IsBackendRegistered(const BackendId& id) const;

    /// @return the factory associated with @p id.
    /// NOTE(review): behaviour for an unknown id is not visible from this header.
    FactoryFunction GetFactory(const BackendId& id) const;

    /// @return a count of entries; presumably the number of registered
    ///         factories - confirm against the definition.
    size_t Size() const;

    /// @return a set of backend ids; presumably those with a registered factory.
    BackendIdSet GetBackendIds() const;

    /// @return a textual rendering of backend ids; the exact format is defined
    ///         out of line.
    std::string GetBackendIdsAsString() const;

    /// Supplies an optional reference to a profiling service.
    /// Only a reference is held (see m_ProfilingService), so the referenced
    /// service has to outlive any use the registry makes of it.
    void SetProfilingService(armnn::Optional<profiling::ProfilingService&> profilingService);

    /// Registers the custom allocator @p alloc for backend @p id.
    void RegisterAllocator(const BackendId& id, std::shared_ptr<ICustomAllocator> alloc);

    /// @return a map from backend id to custom allocator, returned by value
    ///         (presumably a copy of the registry's allocator map).
    std::unordered_map<BackendId, std::shared_ptr<ICustomAllocator>> GetAllocators();

    BackendRegistry() = default;
    virtual ~BackendRegistry() = default;

    /// Helper whose constructor registers a factory with a registry, so that
    /// defining an object of this type performs the registration.
    struct StaticRegistryInitializer
    {
        /// Calls instance.Register(id, factory).
        StaticRegistryInitializer(BackendRegistry& instance,
                                  const BackendId& id,
                                  FactoryFunction factory)
        {
            // 'factory' is already our own by-value copy and is not used again,
            // so hand it over instead of copying the std::function a second time.
            instance.Register(id, std::move(factory));
        }
    };

    /// Counterpart of Register() for backend @p id.
    void Deregister(const BackendId& id);

    /// Counterpart of RegisterAllocator() for backend @p id.
    void DeregisterAllocator(const BackendId &id);

protected:
    /// Container type holding the id -> factory associations.
    using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;

    /// For testing only
    static void Swap(BackendRegistry& instance, FactoryStorage& other);

private:
    // Non-copyable.
    BackendRegistry(const BackendRegistry&) = delete;
    BackendRegistry& operator=(const BackendRegistry&) = delete;

    /// Factory function per backend id.
    FactoryStorage m_Factories;
    /// Optional, non-owning reference to a profiling service.
    armnn::Optional<profiling::ProfilingService&> m_ProfilingService;
    /// Custom allocator per backend id (shared ownership).
    std::unordered_map<BackendId, std::shared_ptr<ICustomAllocator>> m_CustomMemoryAllocatorMap;
};

/// Returns a reference to a BackendRegistry. Defined out of line; presumably
/// this is the single process-wide registry instance - confirm against the
/// definition.
BackendRegistry& BackendRegistryInstance();

} // namespace armnn