aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/DynamicBackend.cpp
blob: c576199e1f2e0a120ab95d7ce6335a2e78a627b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "DynamicBackend.hpp"
#include "DynamicBackendUtils.hpp"

namespace armnn
{

// Wraps an already opened shared object as a dynamic backend.
//
// The handle is stored in m_Handle together with DynamicBackendUtils::CloseHandle as its deleter,
// so it is released through that function when m_Handle is destroyed (including when this
// constructor throws after m_Handle has been initialised).
//
// The three backend entry points ("GetBackendId", "GetVersion" and "BackendFactory") are resolved
// immediately, and the version reported by the backend is checked against the current Backend API.
//
// Throws InvalidArgumentException if sharedObjectHandle is null.
// Throws RuntimeException if the reported backend version is not compatible with the Backend API;
// resolving the entry points and querying the id/version may throw as well.
DynamicBackend::DynamicBackend(const void* sharedObjectHandle)
    : m_BackendIdFunction(nullptr)
    , m_BackendVersionFunction(nullptr)
    , m_BackendFactoryFunction(nullptr)
    , m_Handle(const_cast<void*>(sharedObjectHandle), &DynamicBackendUtils::CloseHandle)
{
    if (m_Handle == nullptr)
    {
        throw InvalidArgumentException("Cannot create a DynamicBackend object from an invalid shared object handle");
    }

    // These calls will throw in case of error
    m_BackendIdFunction      = SetFunctionPointer<IdFunctionType>("GetBackendId");
    m_BackendVersionFunction = SetFunctionPointer<VersionFunctionType>("GetVersion");
    m_BackendFactoryFunction = SetFunctionPointer<FactoryFunctionType>("BackendFactory");

    // Check that the backend is compatible with the current Backend API
    BackendId backendId = GetBackendId();
    BackendVersion backendVersion = GetBackendVersion();
    if (!DynamicBackendUtils::IsBackendCompatible(backendVersion))
    {
        // Adjacent string literals are concatenated verbatim, so the separating space
        // has to be part of one of the two literals
        throw RuntimeException(boost::str(boost::format("The dynamic backend %1% (version %2%) is not compatible "
                                                        "with the current Backend API (version %3%)")
                                          % backendId
                                          % backendVersion
                                          % IBackendInternal::GetApiVersion()));
    }
}

// Asks the backend for its id through the stored id function pointer.
//
// Returns a BackendId built from the C string handed back by the backend.
// Throws RuntimeException if the function pointer is not set, or if the backend returns a null string.
BackendId DynamicBackend::GetBackendId()
{
    if (!m_BackendIdFunction)
    {
        throw RuntimeException("GetBackendId error: invalid function pointer");
    }

    // Only a non-null string can be turned into a BackendId
    if (const char* const rawId = m_BackendIdFunction())
    {
        return BackendId(rawId);
    }

    throw RuntimeException("GetBackendId error: invalid backend id");
}

// Asks the backend for its version through the stored version function pointer.
//
// Returns a BackendVersion holding the major and minor numbers written by the backend.
// Throws RuntimeException if the function pointer is not set.
BackendVersion DynamicBackend::GetBackendVersion()
{
    if (!m_BackendVersionFunction)
    {
        throw RuntimeException("GetBackendVersion error: invalid function pointer");
    }

    // Both fields start at zero and are filled in by the backend through the out-pointers
    uint32_t versionMajor = 0;
    uint32_t versionMinor = 0;
    m_BackendVersionFunction(&versionMajor, &versionMinor);

    return BackendVersion{ versionMajor, versionMinor };
}

// Creates a new backend instance and hands it to the caller.
//
// All the work is delegated to CreateBackend(), which throws in case of error.
IBackendInternalUniquePtr DynamicBackend::GetBackend()
{
    IBackendInternalUniquePtr backendInstance = CreateBackend();
    return backendInstance;
}

// Returns a callable that creates a new backend instance each time it is invoked.
//
// Throws RuntimeException if the factory function pointer is not set.
BackendRegistry::FactoryFunction DynamicBackend::GetFactoryFunction()
{
    if (!m_BackendFactoryFunction)
    {
        throw RuntimeException("GetFactoryFunction error: invalid function pointer");
    }

    // The closure captures `this` by pointer, so it is only valid to call while this
    // DynamicBackend object is alive. CreateBackend() throws in case of error.
    BackendRegistry::FactoryFunction factory = [this]() -> IBackendInternalUniquePtr
    {
        return CreateBackend();
    };

    return factory;
}

// Looks up the entry point called backendFunctionName in the shared object held by m_Handle.
//
// BackendFunctionType is the function pointer type the entry point is retrieved as.
// Returns the entry point obtained from DynamicBackendUtils::GetEntryPoint.
// Throws RuntimeException if the handle is null, if the name is empty, or if the lookup
// yields a null entry point; the lookup itself may throw as well.
template<typename BackendFunctionType>
BackendFunctionType DynamicBackend::SetFunctionPointer(const std::string& backendFunctionName)
{
    if (!m_Handle)
    {
        throw RuntimeException("SetFunctionPointer error: invalid shared object handle");
    }

    if (backendFunctionName.empty())
    {
        throw RuntimeException("SetFunctionPointer error: backend function name must not be empty");
    }

    // The lookup throws in case of error; the null check below is an additional safety net
    const char* const symbolName = backendFunctionName.c_str();
    auto entryPoint = DynamicBackendUtils::GetEntryPoint<BackendFunctionType>(m_Handle.get(), symbolName);
    if (entryPoint == nullptr)
    {
        throw RuntimeException("SetFunctionPointer error: invalid backend function pointer returned");
    }

    return entryPoint;
}

// Invokes the backend factory function and wraps the resulting instance in a unique pointer.
//
// Returns the owning pointer to the newly created backend.
// Throws RuntimeException if the factory function pointer is not set, or if the factory returns null.
IBackendInternalUniquePtr DynamicBackend::CreateBackend()
{
    if (!m_BackendFactoryFunction)
    {
        throw RuntimeException("CreateBackend error: invalid function pointer");
    }

    // The factory's result is reinterpreted as a pointer to the backend interface
    IBackendInternal* const rawBackend = reinterpret_cast<IBackendInternal*>(m_BackendFactoryFunction());
    if (!rawBackend)
    {
        throw RuntimeException("CreateBackend error: backend instance must not be null");
    }

    // From here on the unique pointer owns the instance
    return std::unique_ptr<IBackendInternal>(rawBackend);
}

} // namespace armnn