aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Threadpool.hpp
blob: e2458dbb658e15a5a2934b9756bc7f512f950283 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include "INetwork.hpp"
#include "IRuntime.hpp"

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <vector>

namespace armnn
{
namespace experimental
{
/// Pool of worker threads that runs asynchronous executions of networks
/// loaded into an IRuntime.
///
/// Requests submitted through Schedule() are held in one of three queues
/// (high / medium / low). Presumably the queue is chosen from the request's
/// QosExecPriority; the dispatch logic lives in the implementation file.
///
/// The pool owns its worker threads, a mutex and a condition variable, so it is
/// neither copyable nor movable. The destructor calls TerminateThreadPool().
class Threadpool
{
public:
    /// Creates the pool.
    /// @param numThreads  Number of worker threads to create.
    /// @param runtimePtr  Runtime that executes the scheduled work. Held as a raw,
    ///                    non-owning pointer. NOTE(review): presumably it must
    ///                    outlive the pool; confirm against the implementation.
    /// @param memHandles  Working-memory handles to register with the pool
    ///                    (see LoadMemHandles()).
    Threadpool(std::size_t numThreads,
               IRuntime* runtimePtr,
               std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);

    /// Shuts the pool down via TerminateThreadPool().
    ~Threadpool()
    {
        TerminateThreadPool();
    }

    // The pool owns running threads plus a mutex/condition variable pair, so a
    // copy or move has no sensible meaning. std::mutex already makes these
    // implicitly deleted; spell it out so the intent is explicit (Rule of Five).
    Threadpool(const Threadpool&) = delete;
    Threadpool& operator=(const Threadpool&) = delete;
    Threadpool(Threadpool&&) = delete;
    Threadpool& operator=(Threadpool&&) = delete;

    /// Registers additional working-memory handles with the pool.
    /// Handles are stored per NetworkId in m_WorkingMemHandleMap.
    /// NOTE(review): presumably keyed by each handle's network; confirm.
    void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);

    /// Releases the working-memory handles registered for @p networkId
    /// (presumably removes its entry from m_WorkingMemHandleMap; confirm).
    void UnloadMemHandles(NetworkId networkId);

    /// Schedule an asynchronous execution on the loaded network.
    /// @param networkId      Network to execute.
    /// @param inputTensors   Input bindings for the execution.
    /// @param outputTensors  Output bindings for the execution.
    /// @param priority       Requested QoS priority of the execution.
    /// @param cb             Callback associated with this execution.
    void Schedule(NetworkId networkId,
                  const InputTensors &inputTensors,
                  const OutputTensors &outputTensors,
                  const QosExecPriority priority,
                  std::shared_ptr<IAsyncExecutionCallback> cb);

    /// Stops the pool. Called from the destructor, which is why it is noexcept.
    /// NOTE(review): presumably sets m_TerminatePool, notifies m_ThreadPoolEvent
    /// and joins m_Threads; confirm in the implementation.
    void TerminateThreadPool() noexcept;

private:
    /// One queued request: target network, input/output bindings, and callback.
    using ExecutionTuple = std::tuple<NetworkId,
                                      InputTensors,
                                      OutputTensors,
                                      std::shared_ptr<IAsyncExecutionCallback>>;

    using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;

    /// Body run by the worker threads (presumably); @p index identifies the
    /// worker. TODO(review): confirm what the index selects.
    void ProcessExecPriorities(uint32_t index);

    /// Non-owning; initialised to nullptr so it is never indeterminate.
    IRuntime* m_RuntimePtr = nullptr;

    ExecutionQueue m_HighPriorityQueue;
    ExecutionQueue m_MediumPriorityQueue;
    ExecutionQueue m_LowPriorityQueue;

    // A condition variable needs a mutex to guard the shared state it waits on.
    // m_ThreadPoolEvent signals that an event has happened, e.g. a stop request.
    std::condition_variable m_ThreadPoolEvent;
    std::mutex m_ThreadPoolMutex;

    // Shared state observed through the condition variable: set to request shutdown.
    bool m_TerminatePool = false;

    std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap;
    std::vector<std::unique_ptr<std::thread>> m_Threads;
};

} // namespace experimental

} // namespace armnn