diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-04-20 21:21:07 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2020-04-27 20:50:53 +0100 |
commit | 2ed809cb4765306b7af9b6968e2ec609d143979b (patch) | |
tree | a26b4d4e841434802c01b11a202ec58acf3cd61f /tests/profiling/gatordmock | |
parent | 4e755a50e35a1f5ac1b011dc4baf89e6d97f116e (diff) | |
download | armnn-2ed809cb4765306b7af9b6968e2ec609d143979b.tar.gz |
IVGCVSW-4594 Refactor the GatordMockService and GatordMockMain to extract a BasePipeServer
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I03c1b46104dadc491dba6075865e486f78aa60fa
Diffstat (limited to 'tests/profiling/gatordmock')
-rw-r--r-- | tests/profiling/gatordmock/GatordMockMain.cpp | 32 | ||||
-rw-r--r-- | tests/profiling/gatordmock/GatordMockService.cpp | 357 | ||||
-rw-r--r-- | tests/profiling/gatordmock/GatordMockService.hpp | 83 | ||||
-rw-r--r-- | tests/profiling/gatordmock/tests/GatordMockTests.cpp | 41 |
4 files changed, 63 insertions, 450 deletions
diff --git a/tests/profiling/gatordmock/GatordMockMain.cpp b/tests/profiling/gatordmock/GatordMockMain.cpp index 029c58f5e8..0dbddeb048 100644 --- a/tests/profiling/gatordmock/GatordMockMain.cpp +++ b/tests/profiling/gatordmock/GatordMockMain.cpp @@ -5,12 +5,10 @@ #include "CommandFileParser.hpp" #include "CommandLineProcessor.hpp" +#include <ConnectionHandler.hpp> #include "GatordMockService.hpp" -#include <TimelineDecoder.hpp> -#include <iostream> #include <string> -#include <NetworkSockets.hpp> #include <signal.h> using namespace armnn; @@ -24,11 +22,13 @@ void exit_capture(int signum) run = false; } -bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string commandFile, bool isEchoEnabled) +bool CreateMockService(std::unique_ptr<armnnProfiling::BasePipeServer> basePipeServer, + std::string commandFile, + bool isEchoEnabled) { - GatordMockService mockService(clientConnection, isEchoEnabled); + GatordMockService mockService(std::move(basePipeServer), isEchoEnabled); - // Send receive the strweam metadata and send connection ack. + // Send receive the stream metadata and send connection ack. if (!mockService.WaitForStreamMetaData()) { return EXIT_FAILURE; @@ -63,31 +63,21 @@ int main(int argc, char* argv[]) std::vector<std::thread> threads; std::string commandFile = cmdLine.GetCommandFile(); - armnnUtils::Sockets::Initialize(); - armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); - - if (!GatordMockService::OpenListeningSocket(listeningSocket, cmdLine.GetUdsNamespace(), 10)) - { - return EXIT_FAILURE; - } - std::cout << "Bound to UDS namespace: \\0" << cmdLine.GetUdsNamespace() << std::endl; - // make the socket non-blocking so we can exit the loop - armnnUtils::Sockets::SetNonBlocking(listeningSocket); + armnnProfiling::ConnectionHandler connectionHandler(cmdLine.GetUdsNamespace(), true); + while (run) { - armnnUtils::Sockets::Socket clientConnection = - armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); + auto basePipeServer = connectionHandler.GetNewBasePipeServer(cmdLine.IsEchoEnabled()); - if (clientConnection > 0) + if (basePipeServer != nullptr) { threads.emplace_back( - std::thread(CreateMockService, clientConnection, commandFile, cmdLine.IsEchoEnabled())); + std::thread(CreateMockService, std::move(basePipeServer), commandFile, cmdLine.IsEchoEnabled())); } std::this_thread::sleep_for(std::chrono::milliseconds(100u)); } - armnnUtils::Sockets::Close(listeningSocket); std::for_each(threads.begin(), threads.end(), [](std::thread& t){t.join();}); }
\ No newline at end of file diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp index dbe4fb3b99..13f688225b 100644 --- a/tests/profiling/gatordmock/GatordMockService.cpp +++ b/tests/profiling/gatordmock/GatordMockService.cpp @@ -13,7 +13,6 @@ #include <armnn/utility/Assert.hpp> #include <cerrno> -#include <fcntl.h> #include <iomanip> #include <iostream> #include <string> @@ -26,97 +25,6 @@ namespace armnn namespace gatordmock { -bool GatordMockService::OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket, - const std::string udsNamespace, - const int numOfConnections) -{ - if (-1 == listeningSocket) - { - std::cerr << ": Socket construction failed: " << strerror(errno) << std::endl; - return false; - } - - sockaddr_un udsAddress; - memset(&udsAddress, 0, sizeof(sockaddr_un)); - // We've set the first element of sun_path to be 0, skip over it and copy the namespace after it. - memcpy(udsAddress.sun_path + 1, udsNamespace.c_str(), strlen(udsNamespace.c_str())); - udsAddress.sun_family = AF_UNIX; - - // Bind the socket to the UDS namespace. - if (-1 == bind(listeningSocket, reinterpret_cast<const sockaddr*>(&udsAddress), sizeof(sockaddr_un))) - { - std::cerr << ": Binding on socket failed: " << strerror(errno) << std::endl; - return false; - } - // Listen for 10 connections. - if (-1 == listen(listeningSocket, numOfConnections)) - { - std::cerr << ": Listen call on socket failed: " << strerror(errno) << std::endl; - return false; - } - return true; -} - -bool GatordMockService::WaitForStreamMetaData() -{ - if (m_EchoPackets) - { - std::cout << "Waiting for stream meta data..." << std::endl; - } - // The start of the stream metadata is 2x32bit words, 0 and packet length. - uint8_t header[8]; - if (!ReadFromSocket(header, 8)) - { - return false; - } - EchoPacket(PacketDirection::ReceivedHeader, header, 8); - // The first word, stream_metadata_identifer, should always be 0. - if (ToUint32(&header[0], TargetEndianness::BeWire) != 0) - { - std::cerr << ": Protocol error. The stream_metadata_identifer was not 0." << std::endl; - return false; - } - - uint8_t pipeMagic[4]; - if (!ReadFromSocket(pipeMagic, 4)) - { - return false; - } - EchoPacket(PacketDirection::ReceivedData, pipeMagic, 4); - - // Before we interpret the length we need to read the pipe_magic word to determine endianness. - if (ToUint32(&pipeMagic[0], TargetEndianness::BeWire) == PIPE_MAGIC) - { - m_Endianness = TargetEndianness::BeWire; - } - else if (ToUint32(&pipeMagic[0], TargetEndianness::LeWire) == PIPE_MAGIC) - { - m_Endianness = TargetEndianness::LeWire; - } - else - { - std::cerr << ": Protocol read error. Unable to read PIPE_MAGIC value." << std::endl; - return false; - } - // Now we know the endianness we can get the length from the header. - // Remember we already read the pipe magic 4 bytes. - uint32_t metaDataLength = ToUint32(&header[4], m_Endianness) - 4; - // Read the entire packet. - std::vector<uint8_t> packetData(metaDataLength); - if (metaDataLength != - boost::numeric_cast<uint32_t>(Sockets::Read(m_ClientConnection, packetData.data(), metaDataLength))) - { - std::cerr << ": Protocol read error. Data length mismatch." << std::endl; - return false; - } - EchoPacket(PacketDirection::ReceivedData, packetData.data(), metaDataLength); - m_StreamMetaDataVersion = ToUint32(&packetData[0], m_Endianness); - m_StreamMetaDataMaxDataLen = ToUint32(&packetData[4], m_Endianness); - m_StreamMetaDataPid = ToUint32(&packetData[8], m_Endianness); - - return true; -} - void GatordMockService::SendConnectionAck() { if (m_EchoPackets) @@ -124,7 +32,7 @@ void GatordMockService::SendConnectionAck() std::cout << "Sending connection acknowledgement." << std::endl; } // The connection ack packet is an empty data packet with packetId == 1. - SendPacket(0, 1, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 1, nullptr, 0); } void GatordMockService::SendRequestCounterDir() @@ -134,7 +42,7 @@ void GatordMockService::SendRequestCounterDir() std::cout << "Sending connection acknowledgement." << std::endl; } // The request counter directory packet is an empty data packet with packetId == 3. - SendPacket(0, 3, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 3, nullptr, 0); } void GatordMockService::SendActivateTimelinePacket() @@ -144,7 +52,7 @@ void GatordMockService::SendActivateTimelinePacket() std::cout << "Sending activate timeline packet." << std::endl; } // The activate timeline packet is an empty data packet with packetId == 6. - SendPacket(0, 6, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 6, nullptr, 0); } void GatordMockService::SendDeactivateTimelinePacket() @@ -154,7 +62,7 @@ void GatordMockService::SendDeactivateTimelinePacket() std::cout << "Sending deactivate timeline packet." << std::endl; } // The deactivate timeline packet is an empty data packet with packetId == 7. - SendPacket(0, 7, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 7, nullptr, 0); } bool GatordMockService::LaunchReceivingThread() @@ -164,13 +72,13 @@ bool GatordMockService::LaunchReceivingThread() std::cout << "Launching receiving thread." << std::endl; } // At this point we want to make the socket non blocking. - if (!Sockets::SetNonBlocking(m_ClientConnection)) + if (!m_BasePipeServer.get()->SetNonBlocking()) { - Sockets::Close(m_ClientConnection); + m_BasePipeServer.get()->Close(); std::cerr << "Failed to set socket as non blocking: " << strerror(errno) << std::endl; return false; } - m_ListeningThread = std::thread(&GatordMockService::ReceiveLoop, this, std::ref(*this)); + m_ListeningThread = std::thread(&GatordMockService::ReceiveLoop, this); return true; } @@ -194,6 +102,11 @@ void GatordMockService::WaitForReceivingThread() } } +bool GatordMockService::WaitForStreamMetaData() +{ + return m_BasePipeServer->WaitForStreamMetaData(); +} + void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::vector<uint16_t> counters) { // The packet body consists of a UINT32 representing the period following by zero or more @@ -204,7 +117,6 @@ void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::v { std::cout << "SendPeriodicCounterSelectionList: Period=" << std::dec << period << "uSec" << std::endl; std::cout << "List length=" << counters.size() << std::endl; - ; } // Start by calculating the length of the packet body in bytes. This will be at least 4. uint32_t dataLength = static_cast<uint32_t>(4 + (counters.size() * 2)); @@ -222,7 +134,7 @@ void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::v } // Send the packet. - SendPacket(0, 4, data, dataLength); + m_BasePipeServer.get()->SendPacket(0, 4, data, dataLength); // There will be an echo response packet sitting in the receive thread. PeriodicCounterSelectionResponseHandler // should deal with it. } @@ -245,14 +157,29 @@ void GatordMockService::WaitCommand(uint32_t timeout) } } -void GatordMockService::ReceiveLoop(GatordMockService& mockService) +void GatordMockService::ReceiveLoop() { m_CloseReceivingThread.store(false); while (!m_CloseReceivingThread.load()) { try { - armnn::profiling::Packet packet = mockService.WaitForPacket(500); + profiling::Packet packet = m_BasePipeServer.get()->WaitForPacket(500); + + profiling::PacketVersionResolver packetVersionResolver; + + profiling::Version version = + packetVersionResolver.ResolvePacketVersion(packet.GetPacketFamily(), packet.GetPacketId()); + + profiling::CommandHandlerFunctor* commandHandlerFunctor = m_HandlerRegistry.GetFunctor( + packet.GetPacketFamily(), + packet.GetPacketId(), + version.GetEncodedValue()); + + + + ARMNN_ASSERT(commandHandlerFunctor); + commandHandlerFunctor->operator()(packet); } catch (const armnn::TimeoutException&) { @@ -272,230 +199,6 @@ void GatordMockService::ReceiveLoop(GatordMockService& mockService) } } -armnn::profiling::Packet GatordMockService::WaitForPacket(uint32_t timeoutMs) -{ - // Is there currently more than a headers worth of data waiting to be read? - int bytes_available; - Sockets::Ioctl(m_ClientConnection, FIONREAD, &bytes_available); - if (bytes_available > 8) - { - // Yes there is. Read it: - return ReceivePacket(); - } - else - { - // No there's not. Poll for more data. - struct pollfd pollingFd[1]{}; - pollingFd[0].fd = m_ClientConnection; - int pollResult = Sockets::Poll(pollingFd, 1, static_cast<int>(timeoutMs)); - - switch (pollResult) - { - // Error - case -1: - throw armnn::RuntimeException(std::string("File descriptor reported an error during polling: ") + - strerror(errno)); - - // Timeout - case 0: - throw armnn::TimeoutException("Timeout while waiting to receive packet."); - - // Normal poll return. It could still contain an error signal - default: - // Check if the socket reported an error - if (pollingFd[0].revents & (POLLNVAL | POLLERR | POLLHUP)) - { - if (pollingFd[0].revents == POLLNVAL) - { - throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLNVAL")); - } - if (pollingFd[0].revents == POLLERR) - { - throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLERR: ") + - strerror(errno)); - } - if (pollingFd[0].revents == POLLHUP) - { - throw armnn::RuntimeException(std::string("Connection closed by remote client: POLLHUP")); - } - } - - // Check if there is data to read - if (!(pollingFd[0].revents & (POLLIN))) - { - // This is a corner case. The socket as been woken up but not with any data. - // We'll throw a timeout exception to loop around again. - throw armnn::TimeoutException("File descriptor was polled but no data was available to receive."); - } - return ReceivePacket(); - } - } -} - -armnn::profiling::Packet GatordMockService::ReceivePacket() -{ - uint32_t header[2]; - if (!ReadHeader(header)) - { - return armnn::profiling::Packet(); - } - // Read data_length bytes from the socket. - std::unique_ptr<unsigned char[]> uniquePacketData = std::make_unique<unsigned char[]>(header[1]); - unsigned char* packetData = reinterpret_cast<unsigned char*>(uniquePacketData.get()); - - if (!ReadFromSocket(packetData, header[1])) - { - return armnn::profiling::Packet(); - } - - EchoPacket(PacketDirection::ReceivedData, packetData, header[1]); - - // Construct received packet - armnn::profiling::PacketVersionResolver packetVersionResolver; - armnn::profiling::Packet packetRx = armnn::profiling::Packet(header[0], header[1], uniquePacketData); - if (m_EchoPackets) - { - std::cout << "Processing packet ID= " << packetRx.GetPacketId() << " Length=" << packetRx.GetLength() - << std::endl; - } - - profiling::Version version = - packetVersionResolver.ResolvePacketVersion(packetRx.GetPacketFamily(), packetRx.GetPacketId()); - - profiling::CommandHandlerFunctor* commandHandlerFunctor = - m_HandlerRegistry.GetFunctor(packetRx.GetPacketFamily(), packetRx.GetPacketId(), version.GetEncodedValue()); - ARMNN_ASSERT(commandHandlerFunctor); - commandHandlerFunctor->operator()(packetRx); - return packetRx; -} - -bool GatordMockService::SendPacket(uint32_t packetFamily, uint32_t packetId, const uint8_t* data, uint32_t dataLength) -{ - // Construct a packet from the id and data given and send it to the client. - // Encode the header. - uint32_t header[2]; - header[0] = packetFamily << 26 | packetId << 16; - header[1] = dataLength; - // Add the header to the packet. - std::vector<uint8_t> packet(8 + dataLength); - InsertU32(header[0], packet.data(), m_Endianness); - InsertU32(header[1], packet.data() + 4, m_Endianness); - // And the rest of the data if there is any. - if (dataLength > 0) - { - memcpy((packet.data() + 8), data, dataLength); - } - EchoPacket(PacketDirection::Sending, packet.data(), packet.size()); - if (-1 == Sockets::Write(m_ClientConnection, packet.data(), packet.size())) - { - std::cerr << ": Failure when writing to client socket: " << strerror(errno) << std::endl; - return false; - } - return true; -} - -bool GatordMockService::ReadHeader(uint32_t headerAsWords[2]) -{ - // The header will always be 2x32bit words. - uint8_t header[8]; - if (!ReadFromSocket(header, 8)) - { - return false; - } - EchoPacket(PacketDirection::ReceivedHeader, header, 8); - headerAsWords[0] = ToUint32(&header[0], m_Endianness); - headerAsWords[1] = ToUint32(&header[4], m_Endianness); - return true; -} - -bool GatordMockService::ReadFromSocket(uint8_t* packetData, uint32_t expectedLength) -{ - // This is a blocking read until either expectedLength has been received or an error is detected. - long totalBytesRead = 0; - while (boost::numeric_cast<uint32_t>(totalBytesRead) < expectedLength) - { - long bytesRead = Sockets::Read(m_ClientConnection, packetData, expectedLength); - if (bytesRead < 0) - { - std::cerr << ": Failure when reading from client socket: " << strerror(errno) << std::endl; - return false; - } - if (bytesRead == 0) - { - std::cerr << ": EOF while reading from client socket." << std::endl; - return false; - } - totalBytesRead += bytesRead; - } - return true; -}; - -void GatordMockService::EchoPacket(PacketDirection direction, uint8_t* packet, size_t lengthInBytes) -{ - // If enabled print the contents of the data packet to the console. - if (m_EchoPackets) - { - if (direction == PacketDirection::Sending) - { - std::cout << "TX " << std::dec << lengthInBytes << " bytes : "; - } - else if (direction == PacketDirection::ReceivedHeader) - { - std::cout << "RX Header " << std::dec << lengthInBytes << " bytes : "; - } - else - { - std::cout << "RX Data " << std::dec << lengthInBytes << " bytes : "; - } - for (unsigned int i = 0; i < lengthInBytes; i++) - { - if ((i % 10) == 0) - { - std::cout << std::endl; - } - std::cout << "0x" << std::setfill('0') << std::setw(2) << std::hex << static_cast<unsigned int>(packet[i]) - << " "; - } - std::cout << std::endl; - } -} - -uint32_t GatordMockService::ToUint32(uint8_t* data, TargetEndianness endianness) -{ - // Extract the first 4 bytes starting at data and push them into a 32bit integer based on the - // specified endianness. - if (endianness == TargetEndianness::BeWire) - { - return static_cast<uint32_t>(data[0]) << 24 | static_cast<uint32_t>(data[1]) << 16 | - static_cast<uint32_t>(data[2]) << 8 | static_cast<uint32_t>(data[3]); - } - else - { - return static_cast<uint32_t>(data[3]) << 24 | static_cast<uint32_t>(data[2]) << 16 | - static_cast<uint32_t>(data[1]) << 8 | static_cast<uint32_t>(data[0]); - } -} - -void GatordMockService::InsertU32(uint32_t value, uint8_t* data, TargetEndianness endianness) -{ - // Take the bytes of a 32bit integer and copy them into char array starting at data considering - // the endianness value. - if (endianness == TargetEndianness::BeWire) - { - *data = static_cast<uint8_t>((value >> 24) & 0xFF); - *(data + 1) = static_cast<uint8_t>((value >> 16) & 0xFF); - *(data + 2) = static_cast<uint8_t>((value >> 8) & 0xFF); - *(data + 3) = static_cast<uint8_t>(value & 0xFF); - } - else - { - *(data + 3) = static_cast<uint8_t>((value >> 24) & 0xFF); - *(data + 2) = static_cast<uint8_t>((value >> 16) & 0xFF); - *(data + 1) = static_cast<uint8_t>((value >> 8) & 0xFF); - *data = static_cast<uint8_t>(value & 0xFF); - } -} - } // namespace gatordmock } // namespace armnn diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp index 232d2565e3..8bad41cdfb 100644 --- a/tests/profiling/gatordmock/GatordMockService.hpp +++ b/tests/profiling/gatordmock/GatordMockService.hpp @@ -7,7 +7,6 @@ #include <CommandHandlerRegistry.hpp> #include <Packet.hpp> -#include <NetworkSockets.hpp> #include <atomic> #include <string> @@ -20,6 +19,8 @@ #include "PeriodicCounterCaptureCommandHandler.hpp" #include "StreamMetadataCommandHandler.hpp" +#include <BasePipeServer.hpp> + #include "PacketVersionResolver.hpp" #include "StubCommandHandler.hpp" @@ -29,19 +30,6 @@ namespace armnn namespace gatordmock { -enum class TargetEndianness -{ - BeWire, - LeWire -}; - -enum class PacketDirection -{ - Sending, - ReceivedHeader, - ReceivedData -}; - /// A class that implements a Mock Gatord server. It will listen on a specified Unix domain socket (UDS) /// namespace for client connections. It will then allow opertaions to manage coutners while receiving counter data. class GatordMockService @@ -49,9 +37,8 @@ class GatordMockService public: /// @param registry reference to a command handler registry. /// @param echoPackets if true the raw packets will be printed to stdout. - GatordMockService(armnnUtils::Sockets::Socket clientConnection, bool echoPackets) - : m_ClientConnection(clientConnection) - , m_PacketsReceivedCount(0) + GatordMockService(std::unique_ptr<armnnProfiling::BasePipeServer> clientConnection, bool echoPackets) + : m_BasePipeServer(std::move(clientConnection)) , m_EchoPackets(echoPackets) , m_CloseReceivingThread(false) , m_PacketVersionResolver() @@ -81,18 +68,11 @@ public: m_HandlerRegistry.RegisterFunctor(&m_TimelineCaptureCommandHandler); } - ~GatordMockService() - { - // We have set SOCK_CLOEXEC on these sockets but we'll close them to be good citizens. - armnnUtils::Sockets::Close(m_ClientConnection); - } + GatordMockService(const GatordMockService&) = delete; + GatordMockService& operator=(const GatordMockService&) = delete; - /// Establish the Unix domain socket and set it to listen for connections. - /// @param udsNamespace the namespace (socket address) associated with the listener. - /// @return true only if the socket has been correctly setup. - static bool OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket, - const std::string udsNamespace, - const int numOfConnections = 1); + GatordMockService(GatordMockService&&) = delete; + GatordMockService& operator=(GatordMockService&&) = delete; /// Once the connection is open wait to receive the stream meta data packet from the client. Reading this /// packet differs from others as we need to determine endianness. @@ -137,21 +117,6 @@ public: /// Execute the WAIT command from the comamnd file. void WaitCommand(uint32_t timeout); - uint32_t GetStreamMetadataVersion() - { - return m_StreamMetaDataVersion; - } - - uint32_t GetStreamMetadataMaxDataLen() - { - return m_StreamMetaDataMaxDataLen; - } - - uint32_t GetStreamMetadataPid() - { - return m_StreamMetaDataPid; - } - profiling::DirectoryCaptureCommandHandler& GetDirectoryCaptureCommandHandler() { return m_DirectoryCaptureCommandHandler; @@ -167,39 +132,11 @@ public: return m_TimelineDirectoryCaptureCommandHandler; } - private: - void ReceiveLoop(GatordMockService& mockService); - - int MainLoop(armnn::profiling::CommandHandlerRegistry& registry, armnnUtils::Sockets::Socket m_ClientConnection); - - /// Block on the client connection until a complete packet has been received. This is a placeholder function to - /// enable early testing of the tool. - /// @return true if a valid packet has been received. - armnn::profiling::Packet WaitForPacket(uint32_t timeoutMs); - - armnn::profiling::Packet ReceivePacket(); - - bool SendPacket(uint32_t packetFamily, uint32_t packetId, const uint8_t* data, uint32_t dataLength); - - void EchoPacket(PacketDirection direction, uint8_t* packet, size_t lengthInBytes); - - bool ReadHeader(uint32_t headerAsWords[2]); - - bool ReadFromSocket(uint8_t* packetData, uint32_t expectedLength); - - uint32_t ToUint32(uint8_t* data, TargetEndianness endianness); - - void InsertU32(uint32_t value, uint8_t* data, TargetEndianness endianness); - - static const uint32_t PIPE_MAGIC = 0x45495434; + void ReceiveLoop(); - TargetEndianness m_Endianness; - uint32_t m_StreamMetaDataVersion; - uint32_t m_StreamMetaDataMaxDataLen; - uint32_t m_StreamMetaDataPid; + std::unique_ptr<armnnProfiling::BasePipeServer> m_BasePipeServer; - armnnUtils::Sockets::Socket m_ClientConnection; std::atomic<uint32_t> m_PacketsReceivedCount; bool m_EchoPackets; diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp index 1929c7aeef..cdedeeb897 100644 --- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp +++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp @@ -4,10 +4,10 @@ // #include <CommandHandlerRegistry.hpp> +#include <ConnectionHandler.hpp> #include <DirectoryCaptureCommandHandler.hpp> -#include <GatordMockService.hpp> +#include <gatordmock/GatordMockService.hpp> #include <LabelsAndEventClasses.hpp> -#include <PeriodicCounterCaptureCommandHandler.hpp> #include <ProfilingService.hpp> #include <TimelinePacketWriterFactory.hpp> @@ -21,6 +21,7 @@ #include <boost/test/test_tools.hpp> #include <boost/test/unit_test_suite.hpp> + BOOST_AUTO_TEST_SUITE(GatordMockTests) using namespace armnn; @@ -229,13 +230,9 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) // Setup the mock service to bind to the UDS. std::string udsNamespace = "gatord_namespace"; - armnnUtils::Sockets::Initialize(); - armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); + BOOST_CHECK_NO_THROW(armnnProfiling::ConnectionHandler connectionHandler(udsNamespace, false)); - if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace)) - { - BOOST_FAIL("Failed to open Listening Socket"); - } + armnnProfiling::ConnectionHandler connectionHandler(udsNamespace, false); // Enable the profiling service. armnn::IRuntime::CreationOptions::ExternalProfilingOptions options; @@ -251,15 +248,11 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) BOOST_CHECK(profilingService.GetCurrentState() == profiling::ProfilingState::NotConnected); profilingService.Update(); - // Connect the profiling service to the mock Gatord. - armnnUtils::Sockets::Socket clientSocket = - armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); - if (-1 == clientSocket) - { - BOOST_FAIL("Failed to connect client"); - } + // Connect the profiling service + auto basePipeServer = connectionHandler.GetNewBasePipeServer(false); - gatordmock::GatordMockService mockService(clientSocket, false); + // Connect the profiling service to the mock Gatord. + gatordmock::GatordMockService mockService(std::move(basePipeServer), false); timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder(); profiling::DirectoryCaptureCommandHandler& directoryCaptureCommandHandler = @@ -377,7 +370,6 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) mockService.WaitForReceivingThread(); options.m_EnableProfiling = false; profilingService.ResetExternalProfilingOptions(options, true); - armnnUtils::Sockets::Close(listeningSocket); // Future tests here will add counters to the ProfilingService, increment values and examine // PeriodicCounterCapture data received. These are yet to be integrated. } @@ -388,22 +380,15 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) // Setup the mock service to bind to the UDS. std::string udsNamespace = "gatord_namespace"; - armnnUtils::Sockets::Initialize(); - armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); - - if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace)) - { - BOOST_FAIL("Failed to open Listening Socket"); - } + armnnProfiling::ConnectionHandler connectionHandler(udsNamespace, false); armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; options.m_ProfilingOptions.m_TimelineEnabled = true; armnn::Runtime runtime(options); - armnnUtils::Sockets::Socket clientConnection; - clientConnection = armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); - gatordmock::GatordMockService mockService(clientConnection, false); + auto basePipeServer = connectionHandler.GetNewBasePipeServer(false); + gatordmock::GatordMockService mockService(std::move(basePipeServer), false); // Read the stream metadata on the mock side. if (!mockService.WaitForStreamMetaData()) @@ -484,8 +469,6 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0); mockService.WaitForReceivingThread(); - armnnUtils::Sockets::Close(listeningSocket); - GetProfilingService(&runtime).Disconnect(); } |