aboutsummaryrefslogtreecommitdiff
path: root/src/profiling/SendCounterPacket.hpp
blob: 11587552b86f0cd0b88859c2b91608b63d9b8251 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "IBufferManager.hpp"
#include "ICounterDirectory.hpp"
#include "ISendCounterPacket.hpp"
#include "IProfilingConnection.hpp"
#include "ProfilingStateMachine.hpp"
#include "ProfilingUtils.hpp"

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <type_traits>

namespace armnn
{

namespace profiling
{

// Implementation of the ISendCounterPacket interface.
//
// The object keeps references to a ProfilingStateMachine and an IBufferManager (neither is owned) and
// owns a std::thread (m_SendThread) plus the mutexes / condition variables declared below.
// Start(), Stop(), Send(), FlushBuffer(), the Send*Packet() overrides and the Create*Record() helpers
// are only declared in this header; their behaviour is defined by the out-of-line implementation.
class SendCounterPacket : public ISendCounterPacket
{
public:
    // Each record is a flat sequence of 32-bit words
    using CategoryRecord        = std::vector<uint32_t>;
    using DeviceRecord          = std::vector<uint32_t>;
    using CounterSetRecord      = std::vector<uint32_t>;
    using EventRecord           = std::vector<uint32_t>;
    // Pairs of (uint16_t, uint32_t); judging by the alias name, (counter index, counter value)
    using IndexValuePairsVector = std::vector<std::pair<uint16_t, uint32_t>>;

    // Builds an idle object: both running flags are cleared, no buffer is flagged as ready to read
    // and no send-thread exception is stored.
    //
    // profilingStateMachine: state machine to refer to; stored by reference, must outlive this object
    // buffer:                buffer manager to refer to; stored by reference, must outlive this object
    // timeout:               stored in m_Timeout for the out-of-line implementation
    //                        NOTE(review): presumably milliseconds, like WaitForPacketSent - confirm
    SendCounterPacket(ProfilingStateMachine& profilingStateMachine, IBufferManager& buffer, int timeout = 1000)
        : m_StateMachine(profilingStateMachine)
        , m_BufferManager(buffer)
        , m_Timeout(timeout)
        , m_IsRunning(false)
        , m_KeepRunning(false)
        // Must be initialised explicitly: a plain bool member would otherwise hold an indeterminate value,
        // and reading it before the first assignment would be undefined behaviour.
        // Kept in declaration order (after m_KeepRunning, before m_SendThreadException).
        , m_ReadyToRead(false)
        , m_SendThreadException(nullptr)
    {}

    ~SendCounterPacket()
    {
        // Don't rethrow when destructing the object
        Stop(false);
    }

    void SendStreamMetaDataPacket() override;

    void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory) override;

    void SendPeriodicCounterCapturePacket(uint64_t timestamp, const IndexValuePairsVector& values) override;

    void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
                                            const std::vector<uint16_t>& selectedCounterIds) override;

    void SetReadyToRead() override;

    // Magic number constant (0x45495434); how it is used is defined by the out-of-line implementation
    static const unsigned int PIPE_MAGIC = 0x45495434;

    void Start(IProfilingConnection& profilingConnection);
    // rethrowSendThreadExceptions: the destructor passes false so that nothing is rethrown during destruction
    void Stop(bool rethrowSendThreadExceptions = true);

    // Returns the current value of the atomic m_IsRunning flag.
    // Const-qualified: it only performs an atomic load, so it is callable on const objects too.
    bool IsRunning() const { return m_IsRunning.load(); }

    // Blocks the calling thread on m_PacketSentWaitCondition for at most `timeout` milliseconds.
    // No predicate is passed to wait_for, so the call may also return early on a spurious wake-up,
    // and the wait result is discarded: callers cannot tell a notification from a timeout.
    void WaitForPacketSent(uint32_t timeout = 1000)
    {
        std::unique_lock<std::mutex> lock(m_PacketSentWaitMutex);
        // Blocks until notified that at least a packet has been sent or until timeout expires.
        m_PacketSentWaitCondition.wait_for(lock, std::chrono::milliseconds(timeout));
    }

private:
    void Send(IProfilingConnection& profilingConnection);

    // Throws an ExceptionType constructed from errorMessage; never returns normally.
    template <typename ExceptionType>
    void CancelOperationAndThrow(const std::string& errorMessage)
    {
        // Throw a runtime exception with the given error message
        throw ExceptionType(errorMessage);
    }

    // Same as the overload above, but first cleans up the in-flight write:
    //  - when ExceptionType is BufferExhaustion, SetReadyToRead() is called before throwing
    //  - when writerBuffer is not null, it is handed back to the buffer manager via Release()
    // Never returns normally.
    template <typename ExceptionType>
    void CancelOperationAndThrow(IPacketBufferPtr& writerBuffer, const std::string& errorMessage)
    {
        if (std::is_same<ExceptionType, armnn::profiling::BufferExhaustion>::value)
        {
            SetReadyToRead();
        }

        if (writerBuffer != nullptr)
        {
            // Cancel the operation
            m_BufferManager.Release(writerBuffer);
        }

        // Throw a runtime exception with the given error message
        throw ExceptionType(errorMessage);
    }

    void FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers = true);

    // Non-owning references supplied at construction
    ProfilingStateMachine& m_StateMachine;
    IBufferManager& m_BufferManager;
    // Value of the constructor's `timeout` argument
    int m_Timeout;
    std::mutex m_WaitMutex;
    std::condition_variable m_WaitCondition;
    std::thread m_SendThread;
    std::atomic<bool> m_IsRunning;
    std::atomic<bool> m_KeepRunning;
    // m_ReadyToRead will be protected by m_WaitMutex
    bool m_ReadyToRead;
    // Initialised to nullptr; Stop() takes a flag controlling whether send-thread exceptions are rethrown
    std::exception_ptr m_SendThreadException;
    // Mutex / condition variable pair used by WaitForPacketSent()
    std::mutex m_PacketSentWaitMutex;
    std::condition_variable m_PacketSentWaitCondition;

protected:
    // Helper methods, protected for testing
    // Each helper takes its result record and an error message as non-const references and returns a bool
    // NOTE(review): presumably true on success and false with errorMessage filled in - confirm in the .cpp
    bool CreateCategoryRecord(const CategoryPtr& category,
                              const Counters& counters,
                              CategoryRecord& categoryRecord,
                              std::string& errorMessage);
    bool CreateDeviceRecord(const DevicePtr& device,
                            DeviceRecord& deviceRecord,
                            std::string& errorMessage);
    bool CreateCounterSetRecord(const CounterSetPtr& counterSet,
                                CounterSetRecord& counterSetRecord,
                                std::string& errorMessage);
    bool CreateEventRecord(const CounterPtr& counter,
                           EventRecord& eventRecord,
                           std::string& errorMessage);
};

} // namespace profiling

} // namespace armnn