From 2ed809cb4765306b7af9b6968e2ec609d143979b Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Mon, 20 Apr 2020 21:21:07 +0100 Subject: IVGCVSW-4594 Refactor the GatordMockService and GatordMockMain to extract a BasePipeServer Signed-off-by: Finn Williams Change-Id: I03c1b46104dadc491dba6075865e486f78aa60fa --- CMakeLists.txt | 25 +- cmake/GlobalConfig.cmake | 1 + .../server/src/basePipeServer/BasePipeServer.cpp | 293 +++++++++++++++++ .../server/src/basePipeServer/BasePipeServer.hpp | 117 +++++++ profiling/server/src/basePipeServer/CMakeLists.txt | 25 ++ .../src/basePipeServer/ConnectionHandler.cpp | 55 ++++ .../src/basePipeServer/ConnectionHandler.hpp | 45 +++ .../basePipeServer/tests/BasePipeServerTests.cpp | 99 ++++++ src/profiling/PacketVersionResolver.cpp | 2 +- src/profiling/PacketVersionResolver.hpp | 2 +- tests/profiling/gatordmock/GatordMockMain.cpp | 32 +- tests/profiling/gatordmock/GatordMockService.cpp | 357 ++------------------- tests/profiling/gatordmock/GatordMockService.hpp | 83 +---- .../profiling/gatordmock/tests/GatordMockTests.cpp | 41 +-- 14 files changed, 724 insertions(+), 453 deletions(-) create mode 100644 profiling/server/src/basePipeServer/BasePipeServer.cpp create mode 100644 profiling/server/src/basePipeServer/BasePipeServer.hpp create mode 100644 profiling/server/src/basePipeServer/CMakeLists.txt create mode 100644 profiling/server/src/basePipeServer/ConnectionHandler.cpp create mode 100644 profiling/server/src/basePipeServer/ConnectionHandler.hpp create mode 100644 profiling/server/src/basePipeServer/tests/BasePipeServerTests.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 54376b6109..586d64c89c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -905,6 +905,12 @@ if(BUILD_UNIT_TESTS) ) endif() + if(BUILD_BASE_PIPE_SERVER) + list(APPEND unittest_sources + profiling/server/src/basePipeServer/tests/BasePipeServerTests.cpp + ) + endif() + foreach(lib ${armnnUnitTestLibraries}) message(STATUS "Adding object library dependency to UnitTests: ${lib}") list(APPEND unittest_sources $) @@ -970,6 +976,10 @@ if(BUILD_UNIT_TESTS) target_link_libraries(UnitTests armnnOnnxParser) endif() + if(BUILD_BASE_PIPE_SERVER) + target_link_libraries(UnitTests armnnBasePipeServer) + endif() + addDllCopyCommands(UnitTests) endif() @@ -1008,10 +1018,22 @@ if (BUILD_ARMNN_SERIALIZER AND (BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD addDllCopyCommands(ArmnnConverter) endif() +if(BUILD_BASE_PIPE_SERVER) + add_subdirectory(profiling/server/src/basePipeServer) +endif() + if(BUILD_TIMELINE_DECODER) add_subdirectory(src/timelineDecoder) endif() +if(BUILD_GATORD_MOCK AND NOT BUILD_BASE_PIPE_SERVER) + message(ERROR, "In order to build GatordMock you must set BUILD_BASE_PIPE_SERVER = YES") +endif() + +if(BUILD_GATORD_MOCK AND NOT BUILD_TIMELINE_DECODER) + message(ERROR, "In order to build GatordMock you must set BUILD_TIMELINE_DECODER = YES") +endif() + if(BUILD_GATORD_MOCK) set(gatord_mock_sources) list(APPEND gatord_mock_sources @@ -1030,7 +1052,7 @@ if(BUILD_GATORD_MOCK) tests/profiling/gatordmock/StubCommandHandler.hpp ) - include_directories(src/profiling tests/profiling tests/profiling/gatordmock src/timelineDecoder) + include_directories(src/profiling src/timelineDecoder profiling/server/src/basePipeServer tests/profiling) add_library_ex(gatordMockService STATIC ${gatord_mock_sources}) target_include_directories(gatordMockService PRIVATE src/armnnUtils) @@ -1040,6 +1062,7 @@ if(BUILD_GATORD_MOCK) target_link_libraries(GatordMock armnn + armnnBasePipeServer timelineDecoder gatordMockService ${Boost_PROGRAM_OPTIONS_LIBRARY} diff --git a/cmake/GlobalConfig.cmake b/cmake/GlobalConfig.cmake index 0df7cd408e..08cbb1b3d0 100644 --- a/cmake/GlobalConfig.cmake +++ b/cmake/GlobalConfig.cmake @@ -24,6 +24,7 @@ option(DYNAMIC_BACKEND_PATHS "Colon seperated list of paths where to load the dy option(BUILD_GATORD_MOCK "Build the Gatord simulator for external profiling testing." ON) option(BUILD_TIMELINE_DECODER "Build the Timeline Decoder for external profiling." ON) option(SHARED_BOOST "Use dynamic linking for boost libraries" OFF) +option(BUILD_BASE_PIPE_SERVER "Build the server to handle external profiling pipe traffic" ON) include(SelectLibraryConfigurations) diff --git a/profiling/server/src/basePipeServer/BasePipeServer.cpp b/profiling/server/src/basePipeServer/BasePipeServer.cpp new file mode 100644 index 0000000000..fde5684160 --- /dev/null +++ b/profiling/server/src/basePipeServer/BasePipeServer.cpp @@ -0,0 +1,293 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include "BasePipeServer.hpp" + +using namespace armnnUtils; + +namespace armnnProfiling +{ + +bool BasePipeServer::ReadFromSocket(uint8_t* packetData, uint32_t expectedLength) +{ + // This is a blocking read until either expectedLength has been received or an error is detected. + long totalBytesRead = 0; + while (boost::numeric_cast(totalBytesRead) < expectedLength) + { + long bytesRead = Sockets::Read(m_ClientConnection, packetData, expectedLength); + if (bytesRead < 0) + { + std::cerr << ": Failure when reading from client socket: " << strerror(errno) << std::endl; + return false; + } + if (bytesRead == 0) + { + std::cerr << ": EOF while reading from client socket." << std::endl; + return false; + } + totalBytesRead += bytesRead; + } + return true; +}; + +bool BasePipeServer::WaitForStreamMetaData() +{ + if (m_EchoPackets) + { + std::cout << "Waiting for stream meta data..." << std::endl; + } + // The start of the stream metadata is 2x32bit words, 0 and packet length. + uint8_t header[8]; + if (!ReadFromSocket(header, 8)) + { + return false; + } + EchoPacket(PacketDirection::ReceivedHeader, header, 8); + // The first word, stream_metadata_identifer, should always be 0. + if (ToUint32(&header[0], TargetEndianness::BeWire) != 0) + { + std::cerr << ": Protocol error. The stream_metadata_identifer was not 0." << std::endl; + return false; + } + + uint8_t pipeMagic[4]; + if (!ReadFromSocket(pipeMagic, 4)) + { + return false; + } + EchoPacket(PacketDirection::ReceivedData, pipeMagic, 4); + + // Before we interpret the length we need to read the pipe_magic word to determine endianness. + if (ToUint32(&pipeMagic[0], TargetEndianness::BeWire) == PIPE_MAGIC) + { + m_Endianness = TargetEndianness::BeWire; + } + else if (ToUint32(&pipeMagic[0], TargetEndianness::LeWire) == PIPE_MAGIC) + { + m_Endianness = TargetEndianness::LeWire; + } + else + { + std::cerr << ": Protocol read error. Unable to read PIPE_MAGIC value." << std::endl; + return false; + } + // Now we know the endianness we can get the length from the header. + // Remember we already read the pipe magic 4 bytes. + uint32_t metaDataLength = ToUint32(&header[4], m_Endianness) - 4; + // Read the entire packet. + std::vector packetData(metaDataLength); + if (metaDataLength != + boost::numeric_cast(Sockets::Read(m_ClientConnection, packetData.data(), metaDataLength))) + { + std::cerr << ": Protocol read error. Data length mismatch." << std::endl; + return false; + } + EchoPacket(PacketDirection::ReceivedData, packetData.data(), metaDataLength); + m_StreamMetaDataVersion = ToUint32(&packetData[0], m_Endianness); + m_StreamMetaDataMaxDataLen = ToUint32(&packetData[4], m_Endianness); + m_StreamMetaDataPid = ToUint32(&packetData[8], m_Endianness); + + return true; +} + +armnn::profiling::Packet BasePipeServer::WaitForPacket(uint32_t timeoutMs) +{ + // Is there currently more than a headers worth of data waiting to be read? + int bytes_available; + Sockets::Ioctl(m_ClientConnection, FIONREAD, &bytes_available); + if (bytes_available > 8) + { + // Yes there is. Read it: + return ReceivePacket(); + } + else + { + // No there's not. Poll for more data. + struct pollfd pollingFd[1]{}; + pollingFd[0].fd = m_ClientConnection; + int pollResult = Sockets::Poll(pollingFd, 1, static_cast(timeoutMs)); + + switch (pollResult) + { + // Error + case -1: + throw armnn::RuntimeException(std::string("File descriptor reported an error during polling: ") + + strerror(errno)); + + // Timeout + case 0: + throw armnn::TimeoutException("Timeout while waiting to receive packet."); + + // Normal poll return. It could still contain an error signal + default: + // Check if the socket reported an error + if (pollingFd[0].revents & (POLLNVAL | POLLERR | POLLHUP)) + { + if (pollingFd[0].revents == POLLNVAL) + { + throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLNVAL")); + } + if (pollingFd[0].revents == POLLERR) + { + throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLERR: ") + + strerror(errno)); + } + if (pollingFd[0].revents == POLLHUP) + { + throw armnn::RuntimeException(std::string("Connection closed by remote client: POLLHUP")); + } + } + + // Check if there is data to read + if (!(pollingFd[0].revents & (POLLIN))) + { + // This is a corner case. The socket as been woken up but not with any data. + // We'll throw a timeout exception to loop around again. + throw armnn::TimeoutException("File descriptor was polled but no data was available to receive."); + } + return ReceivePacket(); + } + } +} + +armnn::profiling::Packet BasePipeServer::ReceivePacket() +{ + uint32_t header[2]; + if (!ReadHeader(header)) + { + return armnn::profiling::Packet(); + } + // Read data_length bytes from the socket. + std::unique_ptr uniquePacketData = std::make_unique(header[1]); + unsigned char* packetData = reinterpret_cast(uniquePacketData.get()); + + if (!ReadFromSocket(packetData, header[1])) + { + return armnn::profiling::Packet(); + } + + EchoPacket(PacketDirection::ReceivedData, packetData, header[1]); + + // Construct received packet + armnn::profiling::Packet packetRx = armnn::profiling::Packet(header[0], header[1], uniquePacketData); + if (m_EchoPackets) + { + std::cout << "Processing packet ID= " << packetRx.GetPacketId() << " Length=" << packetRx.GetLength() + << std::endl; + } + + return packetRx; +} + +bool BasePipeServer::SendPacket(uint32_t packetFamily, uint32_t packetId, const uint8_t* data, uint32_t dataLength) +{ + // Construct a packet from the id and data given and send it to the client. + // Encode the header. + uint32_t header[2]; + header[0] = packetFamily << 26 | packetId << 16; + header[1] = dataLength; + // Add the header to the packet. + std::vector packet(8 + dataLength); + InsertU32(header[0], packet.data(), m_Endianness); + InsertU32(header[1], packet.data() + 4, m_Endianness); + // And the rest of the data if there is any. + if (dataLength > 0) + { + memcpy((packet.data() + 8), data, dataLength); + } + EchoPacket(PacketDirection::Sending, packet.data(), packet.size()); + if (-1 == armnnUtils::Sockets::Write(m_ClientConnection, packet.data(), packet.size())) + { + std::cerr << ": Failure when writing to client socket: " << strerror(errno) << std::endl; + return false; + } + return true; +} + +bool BasePipeServer::ReadHeader(uint32_t headerAsWords[2]) +{ + // The header will always be 2x32bit words. + uint8_t header[8]; + if (!ReadFromSocket(header, 8)) + { + return false; + } + EchoPacket(PacketDirection::ReceivedHeader, header, 8); + headerAsWords[0] = ToUint32(&header[0], m_Endianness); + headerAsWords[1] = ToUint32(&header[4], m_Endianness); + return true; +} + +void BasePipeServer::EchoPacket(PacketDirection direction, uint8_t* packet, size_t lengthInBytes) +{ + // If enabled print the contents of the data packet to the console. + if (m_EchoPackets) + { + if (direction == PacketDirection::Sending) + { + std::cout << "TX " << std::dec << lengthInBytes << " bytes : "; + } + else if (direction == PacketDirection::ReceivedHeader) + { + std::cout << "RX Header " << std::dec << lengthInBytes << " bytes : "; + } + else + { + std::cout << "RX Data " << std::dec << lengthInBytes << " bytes : "; + } + for (unsigned int i = 0; i < lengthInBytes; i++) + { + if ((i % 10) == 0) + { + std::cout << std::endl; + } + std::cout << "0x" << std::setfill('0') << std::setw(2) << std::hex << static_cast(packet[i]) + << " "; + } + std::cout << std::endl; + } +} + +uint32_t BasePipeServer::ToUint32(uint8_t* data, TargetEndianness endianness) +{ + // Extract the first 4 bytes starting at data and push them into a 32bit integer based on the + // specified endianness. + if (endianness == TargetEndianness::BeWire) + { + return static_cast(data[0]) << 24 | static_cast(data[1]) << 16 | + static_cast(data[2]) << 8 | static_cast(data[3]); + } + else + { + return static_cast(data[3]) << 24 | static_cast(data[2]) << 16 | + static_cast(data[1]) << 8 | static_cast(data[0]); + } +} + +void BasePipeServer::InsertU32(uint32_t value, uint8_t* data, TargetEndianness endianness) +{ + // Take the bytes of a 32bit integer and copy them into char array starting at data considering + // the endianness value. + if (endianness == TargetEndianness::BeWire) + { + *data = static_cast((value >> 24) & 0xFF); + *(data + 1) = static_cast((value >> 16) & 0xFF); + *(data + 2) = static_cast((value >> 8) & 0xFF); + *(data + 3) = static_cast(value & 0xFF); + } + else + { + *(data + 3) = static_cast((value >> 24) & 0xFF); + *(data + 2) = static_cast((value >> 16) & 0xFF); + *(data + 1) = static_cast((value >> 8) & 0xFF); + *data = static_cast(value & 0xFF); + } +} + +} // namespace armnnProfiling \ No newline at end of file diff --git a/profiling/server/src/basePipeServer/BasePipeServer.hpp b/profiling/server/src/basePipeServer/BasePipeServer.hpp new file mode 100644 index 0000000000..a150d76278 --- /dev/null +++ b/profiling/server/src/basePipeServer/BasePipeServer.hpp @@ -0,0 +1,117 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +//#include +#include +#include +#include "../../../../src/armnnUtils/NetworkSockets.hpp" +#include "../../../../src/profiling/Packet.hpp" +#include "common/include/SocketConnectionException.hpp" + +namespace armnnProfiling +{ + +enum class TargetEndianness +{ + BeWire, + LeWire +}; + +enum class PacketDirection +{ + Sending, + ReceivedHeader, + ReceivedData +}; +class ConnectionHandler; + +class BasePipeServer +{ + +public: + + BasePipeServer(armnnUtils::Sockets::Socket clientConnection, bool echoPackets) + : m_ClientConnection(clientConnection) + , m_EchoPackets(echoPackets) + {} + + ~BasePipeServer() + { + // We have set SOCK_CLOEXEC on this socket but we'll close it to be good citizens. + armnnUtils::Sockets::Close(m_ClientConnection); + } + + BasePipeServer(const BasePipeServer&) = delete; + BasePipeServer& operator=(const BasePipeServer&) = delete; + + BasePipeServer(BasePipeServer&&) = delete; + BasePipeServer& operator=(BasePipeServer&&) = delete; + + /// Close the client connection + /// @return 0 if successful + int Close() + { + return armnnUtils::Sockets::Close(m_ClientConnection); + } + + /// Send a packet to the client + /// @return true if a valid packet has been sent. + bool SendPacket(uint32_t packetFamily, uint32_t packetId, const uint8_t* data, uint32_t dataLength); + + /// Set the client socket to nonblocking + /// @return true if successful. + bool SetNonBlocking() + { + return armnnUtils::Sockets::SetNonBlocking(m_ClientConnection); + } + + /// Block on the client connection until a complete packet has been received. + /// @return true if a valid packet has been received. + armnn::profiling::Packet WaitForPacket(uint32_t timeoutMs); + + /// Once the connection is open wait to receive the stream meta data packet from the client. Reading this + /// packet differs from others as we need to determine endianness. + /// @return true only if a valid stream meta data packet has been received. + bool WaitForStreamMetaData(); + + uint32_t GetStreamMetadataVersion() + { + return m_StreamMetaDataVersion; + } + + uint32_t GetStreamMetadataMaxDataLen() + { + return m_StreamMetaDataMaxDataLen; + } + + uint32_t GetStreamMetadataPid() + { + return m_StreamMetaDataPid; + } + +private: + + void EchoPacket(PacketDirection direction, uint8_t* packet, size_t lengthInBytes); + bool ReadFromSocket(uint8_t* packetData, uint32_t expectedLength); + bool ReadHeader(uint32_t headerAsWords[2]); + + armnn::profiling::Packet ReceivePacket(); + + uint32_t ToUint32(uint8_t* data, TargetEndianness endianness); + void InsertU32(uint32_t value, uint8_t* data, TargetEndianness endianness); + + armnnUtils::Sockets::Socket m_ClientConnection; + bool m_EchoPackets; + TargetEndianness m_Endianness; + static const uint32_t PIPE_MAGIC = 0x45495434; + + uint32_t m_StreamMetaDataVersion; + uint32_t m_StreamMetaDataMaxDataLen; + uint32_t m_StreamMetaDataPid; +}; + +} // namespace armnnProfiling \ No newline at end of file diff --git a/profiling/server/src/basePipeServer/CMakeLists.txt b/profiling/server/src/basePipeServer/CMakeLists.txt new file mode 100644 index 0000000000..e535cf2e66 --- /dev/null +++ b/profiling/server/src/basePipeServer/CMakeLists.txt @@ -0,0 +1,25 @@ +# +# Copyright © 2020 Arm Ltd. All rights reserved. +# SPDX-License-Identifier: MIT +# + +if(BUILD_BASE_PIPE_SERVER) + set(BasePipeServer_sources) + list(APPEND BasePipeServer_sources + BasePipeServer.cpp + BasePipeServer.hpp + ConnectionHandler.cpp + ConnectionHandler.hpp + ) + + include_directories(src/armnnUtils src/profiling) + + add_library_ex(armnnBasePipeServer SHARED ${BasePipeServer_sources}) + set_target_properties(armnnBasePipeServer PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) + set_target_properties(armnnBasePipeServer PROPERTIES VERSION ${GENERIC_LIB_VERSION} + SOVERSION ${GENERIC_LIB_SOVERSION}) + + install(TARGETS armnnBasePipeServer + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() \ No newline at end of file diff --git a/profiling/server/src/basePipeServer/ConnectionHandler.cpp b/profiling/server/src/basePipeServer/ConnectionHandler.cpp new file mode 100644 index 0000000000..69ccd01050 --- /dev/null +++ b/profiling/server/src/basePipeServer/ConnectionHandler.cpp @@ -0,0 +1,55 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include "ConnectionHandler.hpp" + +using namespace armnnUtils; + +namespace armnnProfiling +{ +ConnectionHandler::ConnectionHandler(const std::string& udsNamespace, const bool setNonBlocking) +{ + Sockets::Initialize(); + m_ListeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); + + if (-1 == m_ListeningSocket) + { + throw SocketConnectionException(": Socket construction failed: ", 1, 1); + } + + sockaddr_un udsAddress; + memset(&udsAddress, 0, sizeof(sockaddr_un)); + // We've set the first element of sun_path to be 0, skip over it and copy the namespace after it. + memcpy(udsAddress.sun_path + 1, udsNamespace.c_str(), strlen(udsNamespace.c_str())); + udsAddress.sun_family = AF_UNIX; + + // Bind the socket to the UDS namespace. + if (-1 == bind(m_ListeningSocket, reinterpret_cast(&udsAddress), sizeof(sockaddr_un))) + { + throw SocketConnectionException(": Binding on socket failed: ", m_ListeningSocket, errno); + } + // Listen for connections. + if (-1 == listen(m_ListeningSocket, 1)) + { + throw SocketConnectionException(": Listen call on socket failed: ", m_ListeningSocket, errno); + } + + if (setNonBlocking) + { + Sockets::SetNonBlocking(m_ListeningSocket); + } +} + +std::unique_ptr ConnectionHandler::GetNewBasePipeServer(const bool echoPackets) +{ + armnnUtils::Sockets::Socket clientConnection = armnnUtils::Sockets::Accept(m_ListeningSocket, nullptr, nullptr, + SOCK_CLOEXEC); + if (clientConnection < 1) + { + return nullptr; + } + return std::make_unique(clientConnection, echoPackets); +} + +} // namespace armnnProfiling \ No newline at end of file diff --git a/profiling/server/src/basePipeServer/ConnectionHandler.hpp b/profiling/server/src/basePipeServer/ConnectionHandler.hpp new file mode 100644 index 0000000000..e7317dc355 --- /dev/null +++ b/profiling/server/src/basePipeServer/ConnectionHandler.hpp @@ -0,0 +1,45 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "../../../../src/armnnUtils/NetworkSockets.hpp" +#include "BasePipeServer.hpp" +#include + +namespace armnnProfiling +{ + +class ConnectionHandler +{ +public: + /// Constructor establishes the Unix domain socket and sets it to listen for connections. + /// @param udsNamespace the namespace (socket address) associated with the listener. + /// @throws SocketConnectionException if the socket has been incorrectly setup. + ConnectionHandler(const std::string& udsNamespace, const bool setNonBlocking); + + ~ConnectionHandler() + { + // We have set SOCK_CLOEXEC on this socket but we'll close it to be good citizens. + armnnUtils::Sockets::Close(m_ListeningSocket); + } + + ConnectionHandler(const ConnectionHandler&) = delete; + ConnectionHandler& operator=(const ConnectionHandler&) = delete; + + ConnectionHandler(ConnectionHandler&&) = delete; + ConnectionHandler& operator=(ConnectionHandler&&) = delete; + + /// Attempt to open a new socket to the client and use it to construct a new basePipeServer + /// @param echoPackets if true the raw packets will be printed to stdout. + /// @return if successful a unique_ptr to a basePipeServer otherwise a nullptr + std::unique_ptr GetNewBasePipeServer(const bool echoPackets); + +private: + + armnnUtils::Sockets::Socket m_ListeningSocket; +}; + +} // namespace armnnProfiling \ No newline at end of file diff --git a/profiling/server/src/basePipeServer/tests/BasePipeServerTests.cpp b/profiling/server/src/basePipeServer/tests/BasePipeServerTests.cpp new file mode 100644 index 0000000000..c85bbe72d3 --- /dev/null +++ b/profiling/server/src/basePipeServer/tests/BasePipeServerTests.cpp @@ -0,0 +1,99 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "../ConnectionHandler.hpp" + +#include +#include + +#include +#include + + +BOOST_AUTO_TEST_SUITE(BasePipeServerTests) + +using namespace armnn; +using namespace armnnProfiling; + +BOOST_AUTO_TEST_CASE(BasePipeServerTest) +{ + // Setup the mock service to bind to the UDS. + std::string udsNamespace = "gatord_namespace"; + + // Try to initialize a listening socket through the ConnectionHandler + BOOST_CHECK_NO_THROW(ConnectionHandler connectionHandler(udsNamespace, true)); + + // The socket should close once we leave the scope of BOOST_CHECK_NO_THROW + // and socketProfilingConnection should fail to connect + BOOST_CHECK_THROW(profiling::SocketProfilingConnection socketProfilingConnection, + armnnProfiling::SocketConnectionException); + + // Try to initialize a listening socket through the ConnectionHandler again + ConnectionHandler connectionHandler(udsNamespace, true); + // socketProfilingConnection should connect now + profiling::SocketProfilingConnection socketProfilingConnection; + BOOST_TEST(socketProfilingConnection.IsOpen()); + + auto basePipeServer = connectionHandler.GetNewBasePipeServer(false); + // GetNewBasePipeServer will return null if it fails to create a socket + BOOST_TEST(basePipeServer.get()); + + profiling::BufferManager bufferManager; + profiling::SendCounterPacket sendCounterPacket(bufferManager); + + // Check that we can receive a StreamMetaDataPacket + sendCounterPacket.SendStreamMetaDataPacket(); + + auto packetBuffer = bufferManager.GetReadableBuffer(); + const unsigned char* readBuffer = packetBuffer->GetReadableData(); + unsigned int readBufferSize = packetBuffer->GetSize(); + + BOOST_TEST(readBuffer); + BOOST_TEST(readBufferSize > 0); + + socketProfilingConnection.WritePacket(readBuffer,readBufferSize); + bufferManager.MarkRead(packetBuffer); + + BOOST_TEST(basePipeServer.get()->WaitForStreamMetaData()); + BOOST_TEST(basePipeServer.get()->GetStreamMetadataPid() == armnnUtils::Processes::GetCurrentId()); + BOOST_TEST(basePipeServer.get()->GetStreamMetadataMaxDataLen() == MAX_METADATA_PACKET_LENGTH); + + // Now try a simple PeriodicCounterSelectionPacket + sendCounterPacket.SendPeriodicCounterSelectionPacket(50, {1,2,3,4,5}); + + packetBuffer = bufferManager.GetReadableBuffer(); + readBuffer = packetBuffer->GetReadableData(); + readBufferSize = packetBuffer->GetSize(); + + BOOST_TEST(readBuffer); + BOOST_TEST(readBufferSize > 0); + + socketProfilingConnection.WritePacket(readBuffer,readBufferSize); + bufferManager.MarkRead(packetBuffer); + + auto packet1 = basePipeServer.get()->WaitForPacket(500); + + BOOST_TEST(!packet1.IsEmpty()); + BOOST_TEST(packet1.GetPacketFamily() == 0); + BOOST_TEST(packet1.GetPacketId() == 4); + BOOST_TEST(packet1.GetLength() == 14); + + // Try and send the packet back to the client + basePipeServer.get()->SendPacket(packet1.GetPacketFamily(), + packet1.GetPacketId(), + packet1.GetData(), + packet1.GetLength()); + + auto packet2 = socketProfilingConnection.ReadPacket(500); + + BOOST_TEST(!packet2.IsEmpty()); + BOOST_TEST(packet2.GetPacketFamily() == 0); + BOOST_TEST(packet2.GetPacketId() == 4); + BOOST_TEST(packet2.GetLength() == 14); + + socketProfilingConnection.Close(); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/src/profiling/PacketVersionResolver.cpp b/src/profiling/PacketVersionResolver.cpp index 4178ace166..689abbb082 100644 --- a/src/profiling/PacketVersionResolver.cpp +++ b/src/profiling/PacketVersionResolver.cpp @@ -60,7 +60,7 @@ Version PacketVersionResolver::ResolvePacketVersion(uint32_t familyId, uint32_t { return Version(1, 1, 0); } - if( packetKey == DectivateTimeLinePacket ) + if( packetKey == DeactivateTimeLinePacket ) { return Version(1, 1, 0); } diff --git a/src/profiling/PacketVersionResolver.hpp b/src/profiling/PacketVersionResolver.hpp index 6222eb02e8..3112f5eac0 100644 --- a/src/profiling/PacketVersionResolver.hpp +++ b/src/profiling/PacketVersionResolver.hpp @@ -34,7 +34,7 @@ private: }; static const PacketKey ActivateTimeLinePacket(0 , 6); -static const PacketKey DectivateTimeLinePacket(0 , 7); +static const PacketKey DeactivateTimeLinePacket(0 , 7); class PacketVersionResolver final { diff --git a/tests/profiling/gatordmock/GatordMockMain.cpp b/tests/profiling/gatordmock/GatordMockMain.cpp index 029c58f5e8..0dbddeb048 100644 --- a/tests/profiling/gatordmock/GatordMockMain.cpp +++ b/tests/profiling/gatordmock/GatordMockMain.cpp @@ -5,12 +5,10 @@ #include "CommandFileParser.hpp" #include "CommandLineProcessor.hpp" +#include #include "GatordMockService.hpp" -#include -#include #include -#include #include using namespace armnn; @@ -24,11 +22,13 @@ void exit_capture(int signum) run = false; } -bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string commandFile, bool isEchoEnabled) +bool CreateMockService(std::unique_ptr basePipeServer, + std::string commandFile, + bool isEchoEnabled) { - GatordMockService mockService(clientConnection, isEchoEnabled); + GatordMockService mockService(std::move(basePipeServer), isEchoEnabled); - // Send receive the strweam metadata and send connection ack. + // Send receive the stream metadata and send connection ack. if (!mockService.WaitForStreamMetaData()) { return EXIT_FAILURE; @@ -63,31 +63,21 @@ int main(int argc, char* argv[]) std::vector threads; std::string commandFile = cmdLine.GetCommandFile(); - armnnUtils::Sockets::Initialize(); - armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); - - if (!GatordMockService::OpenListeningSocket(listeningSocket, cmdLine.GetUdsNamespace(), 10)) - { - return EXIT_FAILURE; - } - std::cout << "Bound to UDS namespace: \\0" << cmdLine.GetUdsNamespace() << std::endl; - // make the socket non-blocking so we can exit the loop - armnnUtils::Sockets::SetNonBlocking(listeningSocket); + armnnProfiling::ConnectionHandler connectionHandler(cmdLine.GetUdsNamespace(), true); + while (run) { - armnnUtils::Sockets::Socket clientConnection = - armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); + auto basePipeServer = connectionHandler.GetNewBasePipeServer(cmdLine.IsEchoEnabled()); - if (clientConnection > 0) + if (basePipeServer != nullptr) { threads.emplace_back( - std::thread(CreateMockService, clientConnection, commandFile, cmdLine.IsEchoEnabled())); + std::thread(CreateMockService, std::move(basePipeServer), commandFile, cmdLine.IsEchoEnabled())); } std::this_thread::sleep_for(std::chrono::milliseconds(100u)); } - armnnUtils::Sockets::Close(listeningSocket); std::for_each(threads.begin(), threads.end(), [](std::thread& t){t.join();}); } \ No newline at end of file diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp index dbe4fb3b99..13f688225b 100644 --- a/tests/profiling/gatordmock/GatordMockService.cpp +++ b/tests/profiling/gatordmock/GatordMockService.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include #include @@ -26,97 +25,6 @@ namespace armnn namespace gatordmock { -bool GatordMockService::OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket, - const std::string udsNamespace, - const int numOfConnections) -{ - if (-1 == listeningSocket) - { - std::cerr << ": Socket construction failed: " << strerror(errno) << std::endl; - return false; - } - - sockaddr_un udsAddress; - memset(&udsAddress, 0, sizeof(sockaddr_un)); - // We've set the first element of sun_path to be 0, skip over it and copy the namespace after it. - memcpy(udsAddress.sun_path + 1, udsNamespace.c_str(), strlen(udsNamespace.c_str())); - udsAddress.sun_family = AF_UNIX; - - // Bind the socket to the UDS namespace. - if (-1 == bind(listeningSocket, reinterpret_cast(&udsAddress), sizeof(sockaddr_un))) - { - std::cerr << ": Binding on socket failed: " << strerror(errno) << std::endl; - return false; - } - // Listen for 10 connections. - if (-1 == listen(listeningSocket, numOfConnections)) - { - std::cerr << ": Listen call on socket failed: " << strerror(errno) << std::endl; - return false; - } - return true; -} - -bool GatordMockService::WaitForStreamMetaData() -{ - if (m_EchoPackets) - { - std::cout << "Waiting for stream meta data..." << std::endl; - } - // The start of the stream metadata is 2x32bit words, 0 and packet length. - uint8_t header[8]; - if (!ReadFromSocket(header, 8)) - { - return false; - } - EchoPacket(PacketDirection::ReceivedHeader, header, 8); - // The first word, stream_metadata_identifer, should always be 0. - if (ToUint32(&header[0], TargetEndianness::BeWire) != 0) - { - std::cerr << ": Protocol error. The stream_metadata_identifer was not 0." << std::endl; - return false; - } - - uint8_t pipeMagic[4]; - if (!ReadFromSocket(pipeMagic, 4)) - { - return false; - } - EchoPacket(PacketDirection::ReceivedData, pipeMagic, 4); - - // Before we interpret the length we need to read the pipe_magic word to determine endianness. - if (ToUint32(&pipeMagic[0], TargetEndianness::BeWire) == PIPE_MAGIC) - { - m_Endianness = TargetEndianness::BeWire; - } - else if (ToUint32(&pipeMagic[0], TargetEndianness::LeWire) == PIPE_MAGIC) - { - m_Endianness = TargetEndianness::LeWire; - } - else - { - std::cerr << ": Protocol read error. Unable to read PIPE_MAGIC value." << std::endl; - return false; - } - // Now we know the endianness we can get the length from the header. - // Remember we already read the pipe magic 4 bytes. - uint32_t metaDataLength = ToUint32(&header[4], m_Endianness) - 4; - // Read the entire packet. - std::vector packetData(metaDataLength); - if (metaDataLength != - boost::numeric_cast(Sockets::Read(m_ClientConnection, packetData.data(), metaDataLength))) - { - std::cerr << ": Protocol read error. Data length mismatch." << std::endl; - return false; - } - EchoPacket(PacketDirection::ReceivedData, packetData.data(), metaDataLength); - m_StreamMetaDataVersion = ToUint32(&packetData[0], m_Endianness); - m_StreamMetaDataMaxDataLen = ToUint32(&packetData[4], m_Endianness); - m_StreamMetaDataPid = ToUint32(&packetData[8], m_Endianness); - - return true; -} - void GatordMockService::SendConnectionAck() { if (m_EchoPackets) @@ -124,7 +32,7 @@ void GatordMockService::SendConnectionAck() std::cout << "Sending connection acknowledgement." << std::endl; } // The connection ack packet is an empty data packet with packetId == 1. - SendPacket(0, 1, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 1, nullptr, 0); } void GatordMockService::SendRequestCounterDir() @@ -134,7 +42,7 @@ void GatordMockService::SendRequestCounterDir() std::cout << "Sending connection acknowledgement." << std::endl; } // The request counter directory packet is an empty data packet with packetId == 3. - SendPacket(0, 3, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 3, nullptr, 0); } void GatordMockService::SendActivateTimelinePacket() @@ -144,7 +52,7 @@ void GatordMockService::SendActivateTimelinePacket() std::cout << "Sending activate timeline packet." << std::endl; } // The activate timeline packet is an empty data packet with packetId == 6. - SendPacket(0, 6, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 6, nullptr, 0); } void GatordMockService::SendDeactivateTimelinePacket() @@ -154,7 +62,7 @@ void GatordMockService::SendDeactivateTimelinePacket() std::cout << "Sending deactivate timeline packet." << std::endl; } // The deactivate timeline packet is an empty data packet with packetId == 7. - SendPacket(0, 7, nullptr, 0); + m_BasePipeServer.get()->SendPacket(0, 7, nullptr, 0); } bool GatordMockService::LaunchReceivingThread() @@ -164,13 +72,13 @@ bool GatordMockService::LaunchReceivingThread() std::cout << "Launching receiving thread." << std::endl; } // At this point we want to make the socket non blocking. - if (!Sockets::SetNonBlocking(m_ClientConnection)) + if (!m_BasePipeServer.get()->SetNonBlocking()) { - Sockets::Close(m_ClientConnection); + m_BasePipeServer.get()->Close(); std::cerr << "Failed to set socket as non blocking: " << strerror(errno) << std::endl; return false; } - m_ListeningThread = std::thread(&GatordMockService::ReceiveLoop, this, std::ref(*this)); + m_ListeningThread = std::thread(&GatordMockService::ReceiveLoop, this); return true; } @@ -194,6 +102,11 @@ void GatordMockService::WaitForReceivingThread() } } +bool GatordMockService::WaitForStreamMetaData() +{ + return m_BasePipeServer->WaitForStreamMetaData(); +} + void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::vector counters) { // The packet body consists of a UINT32 representing the period following by zero or more @@ -204,7 +117,6 @@ void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::v { std::cout << "SendPeriodicCounterSelectionList: Period=" << std::dec << period << "uSec" << std::endl; std::cout << "List length=" << counters.size() << std::endl; - ; } // Start by calculating the length of the packet body in bytes. This will be at least 4. uint32_t dataLength = static_cast(4 + (counters.size() * 2)); @@ -222,7 +134,7 @@ void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::v } // Send the packet. - SendPacket(0, 4, data, dataLength); + m_BasePipeServer.get()->SendPacket(0, 4, data, dataLength); // There will be an echo response packet sitting in the receive thread. PeriodicCounterSelectionResponseHandler // should deal with it. } @@ -245,14 +157,29 @@ void GatordMockService::WaitCommand(uint32_t timeout) } } -void GatordMockService::ReceiveLoop(GatordMockService& mockService) +void GatordMockService::ReceiveLoop() { m_CloseReceivingThread.store(false); while (!m_CloseReceivingThread.load()) { try { - armnn::profiling::Packet packet = mockService.WaitForPacket(500); + profiling::Packet packet = m_BasePipeServer.get()->WaitForPacket(500); + + profiling::PacketVersionResolver packetVersionResolver; + + profiling::Version version = + packetVersionResolver.ResolvePacketVersion(packet.GetPacketFamily(), packet.GetPacketId()); + + profiling::CommandHandlerFunctor* commandHandlerFunctor = m_HandlerRegistry.GetFunctor( + packet.GetPacketFamily(), + packet.GetPacketId(), + version.GetEncodedValue()); + + + + ARMNN_ASSERT(commandHandlerFunctor); + commandHandlerFunctor->operator()(packet); } catch (const armnn::TimeoutException&) { @@ -272,230 +199,6 @@ void GatordMockService::ReceiveLoop(GatordMockService& mockService) } } -armnn::profiling::Packet GatordMockService::WaitForPacket(uint32_t timeoutMs) -{ - // Is there currently more than a headers worth of data waiting to be read? - int bytes_available; - Sockets::Ioctl(m_ClientConnection, FIONREAD, &bytes_available); - if (bytes_available > 8) - { - // Yes there is. Read it: - return ReceivePacket(); - } - else - { - // No there's not. Poll for more data. - struct pollfd pollingFd[1]{}; - pollingFd[0].fd = m_ClientConnection; - int pollResult = Sockets::Poll(pollingFd, 1, static_cast(timeoutMs)); - - switch (pollResult) - { - // Error - case -1: - throw armnn::RuntimeException(std::string("File descriptor reported an error during polling: ") + - strerror(errno)); - - // Timeout - case 0: - throw armnn::TimeoutException("Timeout while waiting to receive packet."); - - // Normal poll return. It could still contain an error signal - default: - // Check if the socket reported an error - if (pollingFd[0].revents & (POLLNVAL | POLLERR | POLLHUP)) - { - if (pollingFd[0].revents == POLLNVAL) - { - throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLNVAL")); - } - if (pollingFd[0].revents == POLLERR) - { - throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLERR: ") + - strerror(errno)); - } - if (pollingFd[0].revents == POLLHUP) - { - throw armnn::RuntimeException(std::string("Connection closed by remote client: POLLHUP")); - } - } - - // Check if there is data to read - if (!(pollingFd[0].revents & (POLLIN))) - { - // This is a corner case. The socket as been woken up but not with any data. - // We'll throw a timeout exception to loop around again. - throw armnn::TimeoutException("File descriptor was polled but no data was available to receive."); - } - return ReceivePacket(); - } - } -} - -armnn::profiling::Packet GatordMockService::ReceivePacket() -{ - uint32_t header[2]; - if (!ReadHeader(header)) - { - return armnn::profiling::Packet(); - } - // Read data_length bytes from the socket. - std::unique_ptr uniquePacketData = std::make_unique(header[1]); - unsigned char* packetData = reinterpret_cast(uniquePacketData.get()); - - if (!ReadFromSocket(packetData, header[1])) - { - return armnn::profiling::Packet(); - } - - EchoPacket(PacketDirection::ReceivedData, packetData, header[1]); - - // Construct received packet - armnn::profiling::PacketVersionResolver packetVersionResolver; - armnn::profiling::Packet packetRx = armnn::profiling::Packet(header[0], header[1], uniquePacketData); - if (m_EchoPackets) - { - std::cout << "Processing packet ID= " << packetRx.GetPacketId() << " Length=" << packetRx.GetLength() - << std::endl; - } - - profiling::Version version = - packetVersionResolver.ResolvePacketVersion(packetRx.GetPacketFamily(), packetRx.GetPacketId()); - - profiling::CommandHandlerFunctor* commandHandlerFunctor = - m_HandlerRegistry.GetFunctor(packetRx.GetPacketFamily(), packetRx.GetPacketId(), version.GetEncodedValue()); - ARMNN_ASSERT(commandHandlerFunctor); - commandHandlerFunctor->operator()(packetRx); - return packetRx; -} - -bool GatordMockService::SendPacket(uint32_t packetFamily, uint32_t packetId, const uint8_t* data, uint32_t dataLength) -{ - // Construct a packet from the id and data given and send it to the client. - // Encode the header. - uint32_t header[2]; - header[0] = packetFamily << 26 | packetId << 16; - header[1] = dataLength; - // Add the header to the packet. - std::vector packet(8 + dataLength); - InsertU32(header[0], packet.data(), m_Endianness); - InsertU32(header[1], packet.data() + 4, m_Endianness); - // And the rest of the data if there is any. - if (dataLength > 0) - { - memcpy((packet.data() + 8), data, dataLength); - } - EchoPacket(PacketDirection::Sending, packet.data(), packet.size()); - if (-1 == Sockets::Write(m_ClientConnection, packet.data(), packet.size())) - { - std::cerr << ": Failure when writing to client socket: " << strerror(errno) << std::endl; - return false; - } - return true; -} - -bool GatordMockService::ReadHeader(uint32_t headerAsWords[2]) -{ - // The header will always be 2x32bit words. - uint8_t header[8]; - if (!ReadFromSocket(header, 8)) - { - return false; - } - EchoPacket(PacketDirection::ReceivedHeader, header, 8); - headerAsWords[0] = ToUint32(&header[0], m_Endianness); - headerAsWords[1] = ToUint32(&header[4], m_Endianness); - return true; -} - -bool GatordMockService::ReadFromSocket(uint8_t* packetData, uint32_t expectedLength) -{ - // This is a blocking read until either expectedLength has been received or an error is detected. - long totalBytesRead = 0; - while (boost::numeric_cast(totalBytesRead) < expectedLength) - { - long bytesRead = Sockets::Read(m_ClientConnection, packetData, expectedLength); - if (bytesRead < 0) - { - std::cerr << ": Failure when reading from client socket: " << strerror(errno) << std::endl; - return false; - } - if (bytesRead == 0) - { - std::cerr << ": EOF while reading from client socket." << std::endl; - return false; - } - totalBytesRead += bytesRead; - } - return true; -}; - -void GatordMockService::EchoPacket(PacketDirection direction, uint8_t* packet, size_t lengthInBytes) -{ - // If enabled print the contents of the data packet to the console. - if (m_EchoPackets) - { - if (direction == PacketDirection::Sending) - { - std::cout << "TX " << std::dec << lengthInBytes << " bytes : "; - } - else if (direction == PacketDirection::ReceivedHeader) - { - std::cout << "RX Header " << std::dec << lengthInBytes << " bytes : "; - } - else - { - std::cout << "RX Data " << std::dec << lengthInBytes << " bytes : "; - } - for (unsigned int i = 0; i < lengthInBytes; i++) - { - if ((i % 10) == 0) - { - std::cout << std::endl; - } - std::cout << "0x" << std::setfill('0') << std::setw(2) << std::hex << static_cast(packet[i]) - << " "; - } - std::cout << std::endl; - } -} - -uint32_t GatordMockService::ToUint32(uint8_t* data, TargetEndianness endianness) -{ - // Extract the first 4 bytes starting at data and push them into a 32bit integer based on the - // specified endianness. - if (endianness == TargetEndianness::BeWire) - { - return static_cast(data[0]) << 24 | static_cast(data[1]) << 16 | - static_cast(data[2]) << 8 | static_cast(data[3]); - } - else - { - return static_cast(data[3]) << 24 | static_cast(data[2]) << 16 | - static_cast(data[1]) << 8 | static_cast(data[0]); - } -} - -void GatordMockService::InsertU32(uint32_t value, uint8_t* data, TargetEndianness endianness) -{ - // Take the bytes of a 32bit integer and copy them into char array starting at data considering - // the endianness value. - if (endianness == TargetEndianness::BeWire) - { - *data = static_cast((value >> 24) & 0xFF); - *(data + 1) = static_cast((value >> 16) & 0xFF); - *(data + 2) = static_cast((value >> 8) & 0xFF); - *(data + 3) = static_cast(value & 0xFF); - } - else - { - *(data + 3) = static_cast((value >> 24) & 0xFF); - *(data + 2) = static_cast((value >> 16) & 0xFF); - *(data + 1) = static_cast((value >> 8) & 0xFF); - *data = static_cast(value & 0xFF); - } -} - } // namespace gatordmock } // namespace armnn diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp index 232d2565e3..8bad41cdfb 100644 --- a/tests/profiling/gatordmock/GatordMockService.hpp +++ b/tests/profiling/gatordmock/GatordMockService.hpp @@ -7,7 +7,6 @@ #include #include -#include #include #include @@ -20,6 +19,8 @@ #include "PeriodicCounterCaptureCommandHandler.hpp" #include "StreamMetadataCommandHandler.hpp" +#include + #include "PacketVersionResolver.hpp" #include "StubCommandHandler.hpp" @@ -29,19 +30,6 @@ namespace armnn namespace gatordmock { -enum class TargetEndianness -{ - BeWire, - LeWire -}; - -enum class PacketDirection -{ - Sending, - ReceivedHeader, - ReceivedData -}; - /// A class that implements a Mock Gatord server. It will listen on a specified Unix domain socket (UDS) /// namespace for client connections. It will then allow opertaions to manage coutners while receiving counter data. class GatordMockService @@ -49,9 +37,8 @@ class GatordMockService public: /// @param registry reference to a command handler registry. /// @param echoPackets if true the raw packets will be printed to stdout. - GatordMockService(armnnUtils::Sockets::Socket clientConnection, bool echoPackets) - : m_ClientConnection(clientConnection) - , m_PacketsReceivedCount(0) + GatordMockService(std::unique_ptr clientConnection, bool echoPackets) + : m_BasePipeServer(std::move(clientConnection)) , m_EchoPackets(echoPackets) , m_CloseReceivingThread(false) , m_PacketVersionResolver() @@ -81,18 +68,11 @@ public: m_HandlerRegistry.RegisterFunctor(&m_TimelineCaptureCommandHandler); } - ~GatordMockService() - { - // We have set SOCK_CLOEXEC on these sockets but we'll close them to be good citizens. - armnnUtils::Sockets::Close(m_ClientConnection); - } + GatordMockService(const GatordMockService&) = delete; + GatordMockService& operator=(const GatordMockService&) = delete; - /// Establish the Unix domain socket and set it to listen for connections. - /// @param udsNamespace the namespace (socket address) associated with the listener. - /// @return true only if the socket has been correctly setup. - static bool OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket, - const std::string udsNamespace, - const int numOfConnections = 1); + GatordMockService(GatordMockService&&) = delete; + GatordMockService& operator=(GatordMockService&&) = delete; /// Once the connection is open wait to receive the stream meta data packet from the client. Reading this /// packet differs from others as we need to determine endianness. @@ -137,21 +117,6 @@ public: /// Execute the WAIT command from the comamnd file. void WaitCommand(uint32_t timeout); - uint32_t GetStreamMetadataVersion() - { - return m_StreamMetaDataVersion; - } - - uint32_t GetStreamMetadataMaxDataLen() - { - return m_StreamMetaDataMaxDataLen; - } - - uint32_t GetStreamMetadataPid() - { - return m_StreamMetaDataPid; - } - profiling::DirectoryCaptureCommandHandler& GetDirectoryCaptureCommandHandler() { return m_DirectoryCaptureCommandHandler; @@ -167,39 +132,11 @@ public: return m_TimelineDirectoryCaptureCommandHandler; } - private: - void ReceiveLoop(GatordMockService& mockService); - - int MainLoop(armnn::profiling::CommandHandlerRegistry& registry, armnnUtils::Sockets::Socket m_ClientConnection); - - /// Block on the client connection until a complete packet has been received. This is a placeholder function to - /// enable early testing of the tool. - /// @return true if a valid packet has been received. - armnn::profiling::Packet WaitForPacket(uint32_t timeoutMs); - - armnn::profiling::Packet ReceivePacket(); - - bool SendPacket(uint32_t packetFamily, uint32_t packetId, const uint8_t* data, uint32_t dataLength); - - void EchoPacket(PacketDirection direction, uint8_t* packet, size_t lengthInBytes); - - bool ReadHeader(uint32_t headerAsWords[2]); - - bool ReadFromSocket(uint8_t* packetData, uint32_t expectedLength); - - uint32_t ToUint32(uint8_t* data, TargetEndianness endianness); - - void InsertU32(uint32_t value, uint8_t* data, TargetEndianness endianness); - - static const uint32_t PIPE_MAGIC = 0x45495434; + void ReceiveLoop(); - TargetEndianness m_Endianness; - uint32_t m_StreamMetaDataVersion; - uint32_t m_StreamMetaDataMaxDataLen; - uint32_t m_StreamMetaDataPid; + std::unique_ptr m_BasePipeServer; - armnnUtils::Sockets::Socket m_ClientConnection; std::atomic m_PacketsReceivedCount; bool m_EchoPackets; diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp index 1929c7aeef..cdedeeb897 100644 --- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp +++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp @@ -4,10 +4,10 @@ // #include +#include #include -#include +#include #include -#include #include #include @@ -21,6 +21,7 @@ #include #include + BOOST_AUTO_TEST_SUITE(GatordMockTests) using namespace armnn; @@ -229,13 +230,9 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) // Setup the mock service to bind to the UDS. std::string udsNamespace = "gatord_namespace"; - armnnUtils::Sockets::Initialize(); - armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); + BOOST_CHECK_NO_THROW(armnnProfiling::ConnectionHandler connectionHandler(udsNamespace, false)); - if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace)) - { - BOOST_FAIL("Failed to open Listening Socket"); - } + armnnProfiling::ConnectionHandler connectionHandler(udsNamespace, false); // Enable the profiling service. armnn::IRuntime::CreationOptions::ExternalProfilingOptions options; @@ -251,15 +248,11 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) BOOST_CHECK(profilingService.GetCurrentState() == profiling::ProfilingState::NotConnected); profilingService.Update(); - // Connect the profiling service to the mock Gatord. - armnnUtils::Sockets::Socket clientSocket = - armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); - if (-1 == clientSocket) - { - BOOST_FAIL("Failed to connect client"); - } + // Connect the profiling service + auto basePipeServer = connectionHandler.GetNewBasePipeServer(false); - gatordmock::GatordMockService mockService(clientSocket, false); + // Connect the profiling service to the mock Gatord. + gatordmock::GatordMockService mockService(std::move(basePipeServer), false); timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder(); profiling::DirectoryCaptureCommandHandler& directoryCaptureCommandHandler = @@ -377,7 +370,6 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) mockService.WaitForReceivingThread(); options.m_EnableProfiling = false; profilingService.ResetExternalProfilingOptions(options, true); - armnnUtils::Sockets::Close(listeningSocket); // Future tests here will add counters to the ProfilingService, increment values and examine // PeriodicCounterCapture data received. These are yet to be integrated. } @@ -388,22 +380,15 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) // Setup the mock service to bind to the UDS. std::string udsNamespace = "gatord_namespace"; - armnnUtils::Sockets::Initialize(); - armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); - - if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace)) - { - BOOST_FAIL("Failed to open Listening Socket"); - } + armnnProfiling::ConnectionHandler connectionHandler(udsNamespace, false); armnn::IRuntime::CreationOptions options; options.m_ProfilingOptions.m_EnableProfiling = true; options.m_ProfilingOptions.m_TimelineEnabled = true; armnn::Runtime runtime(options); - armnnUtils::Sockets::Socket clientConnection; - clientConnection = armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); - gatordmock::GatordMockService mockService(clientConnection, false); + auto basePipeServer = connectionHandler.GetNewBasePipeServer(false); + gatordmock::GatordMockService mockService(std::move(basePipeServer), false); // Read the stream metadata on the mock side. if (!mockService.WaitForStreamMetaData()) @@ -484,8 +469,6 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0); mockService.WaitForReceivingThread(); - armnnUtils::Sockets::Close(listeningSocket); - GetProfilingService(&runtime).Disconnect(); } -- cgit v1.2.1