diff options
Diffstat (limited to 'src/profiling/FileOnlyProfilingConnection.cpp')
-rw-r--r-- | src/profiling/FileOnlyProfilingConnection.cpp | 162 |
1 files changed, 157 insertions, 5 deletions
diff --git a/src/profiling/FileOnlyProfilingConnection.cpp b/src/profiling/FileOnlyProfilingConnection.cpp index f9bdde961f..5947d2c081 100644 --- a/src/profiling/FileOnlyProfilingConnection.cpp +++ b/src/profiling/FileOnlyProfilingConnection.cpp @@ -8,6 +8,7 @@ #include <armnn/Exceptions.hpp> +#include <algorithm> #include <boost/numeric/conversion/cast.hpp> #include <iostream> #include <thread> @@ -32,10 +33,19 @@ bool FileOnlyProfilingConnection::IsOpen() const void FileOnlyProfilingConnection::Close() { // Dump any unread packets out of the queue. - for (unsigned int i = 0; i < m_PacketQueue.size(); i++) + size_t initialSize = m_PacketQueue.size(); + for (size_t i = 0; i < initialSize; ++i) { m_PacketQueue.pop(); } + // dispose of the processing thread + m_KeepRunning.store(false); + if (m_LocalHandlersThread.joinable()) + { + // make sure the thread wakes up and sees it has to stop + m_ConditionPacketReadable.notify_one(); + m_LocalHandlersThread.join(); + } } bool FileOnlyProfilingConnection::WaitForStreamMeta(const unsigned char* buffer, uint32_t length) @@ -112,10 +122,11 @@ bool FileOnlyProfilingConnection::SendCounterSelectionPacket() bool FileOnlyProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length) { ARMNN_ASSERT(buffer); + Packet packet = ReceivePacket(buffer, length); // Read Header and determine case uint32_t outgoingHeaderAsWords[2]; - PackageActivity packageActivity = GetPackageActivity(buffer, outgoingHeaderAsWords); + PackageActivity packageActivity = GetPackageActivity(packet, outgoingHeaderAsWords); switch (packageActivity) { @@ -160,6 +171,7 @@ bool FileOnlyProfilingConnection::WritePacket(const unsigned char* buffer, uint3 break; } } + ForwardPacketToHandlers(packet); return true; } @@ -181,10 +193,10 @@ Packet FileOnlyProfilingConnection::ReadPacket(uint32_t timeout) return returnedPacket; } -PackageActivity FileOnlyProfilingConnection::GetPackageActivity(const unsigned char* buffer, uint32_t headerAsWords[2]) +PackageActivity FileOnlyProfilingConnection::GetPackageActivity(const Packet& packet, uint32_t headerAsWords[2]) { - headerAsWords[0] = ToUint32(buffer, m_Endianness); - headerAsWords[1] = ToUint32(buffer + 4, m_Endianness); + headerAsWords[0] = packet.GetHeader(); + headerAsWords[1] = packet.GetLength(); if (headerAsWords[0] == 0x20000) // Packet family = 0 Packet Id = 2 { return PackageActivity::CounterDirectory; @@ -221,6 +233,146 @@ void FileOnlyProfilingConnection::Fail(const std::string& errorMessage) throw RuntimeException(errorMessage); } +/// Adds a local packet handler to the FileOnlyProfilingConnection. Invoking this will start +/// a processing thread that will ensure that processing of packets will happen on a separate +/// thread from the profiling services send thread and will therefore protect against the +/// profiling message buffer becoming exhausted because packet handling slows the dispatch. +void FileOnlyProfilingConnection::AddLocalPacketHandler(ILocalPacketHandlerSharedPtr localPacketHandler) +{ + m_PacketHandlers.push_back(std::move(localPacketHandler)); + ILocalPacketHandlerSharedPtr localCopy = m_PacketHandlers.back(); + localCopy->SetConnection(this); + if (localCopy->GetHeadersAccepted().empty()) + { + //this is a universal handler + m_UniversalHandlers.push_back(localCopy); + } + else + { + for (uint32_t header : localCopy->GetHeadersAccepted()) + { + auto iter = m_IndexedHandlers.find(header); + if (iter == m_IndexedHandlers.end()) + { + std::vector<ILocalPacketHandlerSharedPtr> handlers; + handlers.push_back(localCopy); + m_IndexedHandlers.emplace(std::make_pair(header, handlers)); + } + else + { + iter->second.push_back(localCopy); + } + } + } +} + +void FileOnlyProfilingConnection::StartProcessingThread() +{ + // check if the thread has already started + if (m_IsRunning.load()) + { + return; + } + // make sure if there was one running before it is joined + if (m_LocalHandlersThread.joinable()) + { + m_LocalHandlersThread.join(); + } + m_IsRunning.store(true); + m_KeepRunning.store(true); + m_LocalHandlersThread = std::thread(&FileOnlyProfilingConnection::ServiceLocalHandlers, this); +} + +void FileOnlyProfilingConnection::ForwardPacketToHandlers(Packet& packet) +{ + if (m_PacketHandlers.empty()) + { + return; + } + if (m_KeepRunning.load() == false) + { + return; + } + { + std::unique_lock<std::mutex> readableListLock(m_ReadableMutex); + if (m_KeepRunning.load() == false) + { + return; + } + m_ReadableList.push(std::move(packet)); + } + m_ConditionPacketReadable.notify_one(); +} + +void FileOnlyProfilingConnection::ServiceLocalHandlers() +{ + do + { + Packet returnedPacket; + bool readPacket = false; + { // only lock while we are taking the packet off the incoming list + std::unique_lock<std::mutex> lck(m_ReadableMutex); + if (m_Timeout < 0) + { + m_ConditionPacketReadable.wait(lck, + [&] { return !m_ReadableList.empty(); }); + } + else + { + m_ConditionPacketReadable.wait_for(lck, + std::chrono::milliseconds(std::max(m_Timeout, 1000)), + [&] { return !m_ReadableList.empty(); }); + } + if (m_KeepRunning.load()) + { + if (!m_ReadableList.empty()) + { + returnedPacket = std::move(m_ReadableList.front()); + m_ReadableList.pop(); + readPacket = true; + } + } + else + { + ClearReadableList(); + } + } + if (m_KeepRunning.load() && readPacket) + { + DispatchPacketToHandlers(returnedPacket); + } + } while (m_KeepRunning.load()); + // make sure the readable list is cleared + ClearReadableList(); + m_IsRunning.store(false); +} + +void FileOnlyProfilingConnection::ClearReadableList() +{ + // make sure the incoming packet queue gets emptied + size_t initialSize = m_ReadableList.size(); + for (size_t i = 0; i < initialSize; ++i) + { + m_ReadableList.pop(); + } +} + +void FileOnlyProfilingConnection::DispatchPacketToHandlers(const Packet& packet) +{ + for (auto& delegate : m_UniversalHandlers) + { + delegate->HandlePacket(packet); + } + auto iter = m_IndexedHandlers.find(packet.GetHeader()); + if (iter != m_IndexedHandlers.end()) + { + for (auto &delegate : iter->second) + { + delegate->HandlePacket(packet); + } + } +} + } // namespace profiling } // namespace armnn |