From 25b7436b02514145a0289daff78f5b9f64cdd0db Mon Sep 17 00:00:00 2001 From: Rob Hughes Date: Mon, 13 Jan 2020 11:14:59 +0000 Subject: Add thin abstraction layer for network sockets This makes SocketProfilingConnection and GatordMock work on Windows as well as Linux Change-Id: I4b10c079b653a1c3f61eb20694e5b5f8a6f5fdfb Signed-off-by: Robert Hughes --- Android.mk | 1 + CMakeLists.txt | 7 ++ src/armnnUtils/NetworkSockets.cpp | 99 ++++++++++++++++++++++ src/armnnUtils/NetworkSockets.hpp | 59 +++++++++++++ src/profiling/SocketProfilingConnection.cpp | 25 +++--- src/profiling/SocketProfilingConnection.hpp | 4 +- tests/profiling/gatordmock/CommandFileParser.cpp | 4 +- tests/profiling/gatordmock/GatordMockService.cpp | 51 ++++++----- tests/profiling/gatordmock/GatordMockService.hpp | 11 +-- .../TimelineCaptureCommandHandler.cpp | 2 +- 10 files changed, 214 insertions(+), 49 deletions(-) create mode 100644 src/armnnUtils/NetworkSockets.cpp create mode 100644 src/armnnUtils/NetworkSockets.hpp diff --git a/Android.mk b/Android.mk index bfaee443a2..8f348d9ec5 100644 --- a/Android.mk +++ b/Android.mk @@ -117,6 +117,7 @@ LOCAL_SRC_FILES := \ src/armnnUtils/Permute.cpp \ src/armnnUtils/TensorUtils.cpp \ src/armnnUtils/VerificationHelpers.cpp \ + src/armnnUtils/NetworkSockets.cpp \ src/armnn/layers/AbsLayer.cpp \ src/armnn/layers/ActivationLayer.cpp \ src/armnn/layers/AdditionLayer.cpp \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d54137937..e39c2b8871 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,6 +65,8 @@ list(APPEND armnnUtils_sources src/armnnUtils/QuantizeHelper.hpp src/armnnUtils/TensorIOUtils.hpp src/armnnUtils/TensorUtils.cpp + src/armnnUtils/NetworkSockets.hpp + src/armnnUtils/NetworkSockets.cpp ) add_library_ex(armnnUtils STATIC ${armnnUtils_sources}) @@ -533,6 +535,9 @@ target_include_directories(armnn PRIVATE src/profiling) target_link_libraries(armnn armnnUtils) target_link_libraries(armnn ${CMAKE_DL_LIBS}) +if ("${CMAKE_SYSTEM_NAME}" STREQUAL Windows) + target_link_libraries(armnn Ws2_32.lib) +endif() install(TARGETS armnn LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -963,8 +968,10 @@ if(BUILD_GATORD_MOCK) include_directories(${Boost_INCLUDE_DIRS} tests/profiling/timelineDecoder) add_library_ex(gatordMockService STATIC ${gatord_mock_sources}) + target_include_directories(gatordMockService PRIVATE src/armnnUtils) add_executable_ex(GatordMock tests/profiling/gatordmock/GatordMockMain.cpp) + target_include_directories(GatordMock PRIVATE src/armnnUtils) target_link_libraries(GatordMock armnn diff --git a/src/armnnUtils/NetworkSockets.cpp b/src/armnnUtils/NetworkSockets.cpp new file mode 100644 index 0000000000..cc28a90c48 --- /dev/null +++ b/src/armnnUtils/NetworkSockets.cpp @@ -0,0 +1,99 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "NetworkSockets.hpp" + +#if defined(__unix__) +#include +#include +#endif + +namespace armnnUtils +{ +namespace Sockets +{ + +bool Initialize() +{ +#if defined(__unix__) + return true; +#elif defined(_MSC_VER) + WSADATA wsaData; + return WSAStartup(MAKEWORD(2, 2), &wsaData) == 0; +#endif +} + +int Close(Socket s) +{ +#if defined(__unix__) + return close(s); +#elif defined(_MSC_VER) + return closesocket(s); +#endif +} + + +bool SetNonBlocking(Socket s) +{ +#if defined(__unix__) + const int currentFlags = fcntl(s, F_GETFL); + return fcntl(s, F_SETFL, currentFlags | O_NONBLOCK) == 0; +#elif defined(_MSC_VER) + u_long mode = 1; + return ioctlsocket(s, FIONBIO, &mode) == 0; +#endif +} + + +long Write(Socket s, const void* buf, size_t len) +{ +#if defined(__unix__) + return write(s, buf, len); +#elif defined(_MSC_VER) + return send(s, static_cast(buf), len, 0); +#endif +} + + +long Read(Socket s, void* buf, size_t len) +{ +#if defined(__unix__) + return read(s, buf, len); +#elif defined(_MSC_VER) + return recv(s, static_cast(buf), len, 0); +#endif +} + +int Ioctl(Socket s, unsigned long cmd, void* arg) +{ +#if defined(__unix__) + return ioctl(s, cmd, arg); +#elif defined(_MSC_VER) + return ioctlsocket(s, cmd, static_cast(arg)); +#endif +} + + +int Poll(PollFd* fds, size_t numFds, int timeout) +{ +#if defined(__unix__) + return poll(fds, numFds, timeout); +#elif defined(_MSC_VER) + return WSAPoll(fds, numFds, timeout); +#endif +} + + +armnnUtils::Sockets::Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags) +{ +#if defined(__unix__) + return accept4(s, addr, addrlen, flags); +#elif defined(_MSC_VER) + return accept(s, addr, reinterpret_cast(addrlen)); +#endif +} + +} +} diff --git a/src/armnnUtils/NetworkSockets.hpp b/src/armnnUtils/NetworkSockets.hpp new file mode 100644 index 0000000000..9e4770793c --- /dev/null +++ b/src/armnnUtils/NetworkSockets.hpp @@ -0,0 +1,59 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +// This file (along with its corresponding .cpp) defines a very thin platform abstraction layer for the use of +// networking sockets. Thankfully the underlying APIs on Windows and Linux are very similar so not much conversion +// is needed (typically just forwarding the parameters to a differently named function). +// Some of the APIs are in fact completely identical and so no forwarding function is needed. + +#pragma once + +#if defined(__unix__) +#include +#include +#include +#include +#elif defined(_MSC_VER) +#include +#include +#endif + +namespace armnnUtils +{ +namespace Sockets +{ + +#if defined(__unix__) + +using Socket = int; +using PollFd = pollfd; + +#elif defined(_MSC_VER) + +using Socket = SOCKET; +using PollFd = WSAPOLLFD; +#define SOCK_CLOEXEC 0 + +#endif + +/// Performs any required one-time setup. +bool Initialize(); + +int Close(Socket s); + +bool SetNonBlocking(Socket s); + +long Write(Socket s, const void* buf, size_t len); + +long Read(Socket s, void* buf, size_t len); + +int Ioctl(Socket s, unsigned long cmd, void* arg); + +int Poll(PollFd* fds, size_t numFds, int timeout); + +Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags); + +} +} diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index c78c182412..4bbbc2962c 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -7,11 +7,10 @@ #include #include -#include -#include -#include #include +using namespace armnnUtils; + namespace armnn { namespace profiling @@ -19,6 +18,7 @@ namespace profiling SocketProfilingConnection::SocketProfilingConnection() { + Sockets::Initialize(); memset(m_Socket, 0, sizeof(m_Socket)); // Note: we're using Linux specific SOCK_CLOEXEC flag. m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); @@ -28,7 +28,7 @@ SocketProfilingConnection::SocketProfilingConnection() } // Connect to the named unix domain socket. - struct sockaddr_un server{}; + sockaddr_un server{}; memset(&server, 0, sizeof(sockaddr_un)); // As m_GatorNamespace begins with a null character we need to ignore that when getting its length. memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1); @@ -43,8 +43,7 @@ SocketProfilingConnection::SocketProfilingConnection() m_Socket[0].events = POLLIN; // Make the socket non blocking. - const int currentFlags = fcntl(m_Socket[0].fd, F_GETFL); - if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK)) + if (!Sockets::SetNonBlocking(m_Socket[0].fd)) { Close(); throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno)); @@ -58,7 +57,7 @@ bool SocketProfilingConnection::IsOpen() const void SocketProfilingConnection::Close() { - if (close(m_Socket[0].fd) != 0) + if (Sockets::Close(m_Socket[0].fd) != 0) { throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno)); } @@ -73,14 +72,14 @@ bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_ return false; } - return write(m_Socket[0].fd, buffer, length) != -1; + return Sockets::Write(m_Socket[0].fd, buffer, length) != -1; } Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) { // Is there currently at least a header worth of data waiting to be read? int bytes_available = 0; - ioctl(m_Socket[0].fd, FIONREAD, &bytes_available); + Sockets::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available); if (bytes_available >= 8) { // Yes there is. Read it: @@ -88,7 +87,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) } // Poll for data on the socket or until timeout occurs - int pollResult = poll(m_Socket, 1, static_cast(timeout)); + int pollResult = Sockets::Poll(&m_Socket[0], 1, static_cast(timeout)); switch (pollResult) { @@ -136,7 +135,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) Packet SocketProfilingConnection::ReceivePacket() { char header[8] = {}; - ssize_t receiveResult = recv(m_Socket[0].fd, &header, sizeof(header), 0); + long receiveResult = Sockets::Read(m_Socket[0].fd, &header, sizeof(header)); // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error. switch( receiveResult ) { @@ -168,10 +167,10 @@ Packet SocketProfilingConnection::ReceivePacket() if (dataLength > 0) { packetData = std::make_unique(dataLength); - ssize_t receivedLength = recv(m_Socket[0].fd, packetData.get(), dataLength, 0); + long receivedLength = Sockets::Read(m_Socket[0].fd, packetData.get(), dataLength); if (receivedLength < 0) { - throw armnn::RuntimeException(std::string("Error occured on recv: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Error occurred on recv: ") + strerror(errno)); } if (dataLength != static_cast(receivedLength)) { diff --git a/src/profiling/SocketProfilingConnection.hpp b/src/profiling/SocketProfilingConnection.hpp index 5fb02bb19e..05c7130de7 100644 --- a/src/profiling/SocketProfilingConnection.hpp +++ b/src/profiling/SocketProfilingConnection.hpp @@ -5,8 +5,8 @@ #include "IProfilingConnection.hpp" -#include #include +#include #pragma once @@ -31,7 +31,7 @@ private: // To indicate we want to use an abstract UDS ensure the first character of the address is 0. const char* m_GatorNamespace = "\0gatord_namespace"; - struct pollfd m_Socket[1]{}; + armnnUtils::Sockets::PollFd m_Socket[1]{}; }; } // namespace profiling diff --git a/tests/profiling/gatordmock/CommandFileParser.cpp b/tests/profiling/gatordmock/CommandFileParser.cpp index 4a8a19b5d2..7c746f16e9 100644 --- a/tests/profiling/gatordmock/CommandFileParser.cpp +++ b/tests/profiling/gatordmock/CommandFileParser.cpp @@ -54,7 +54,7 @@ void CommandFileParser::ParseFile(std::string CommandFile, GatordMockService& mo // 500000 polling period in micro seconds // 1 2 5 10 counter list - uint period = static_cast(std::stoul(tokens[1])); + uint32_t period = static_cast(std::stoul(tokens[1])); std::vector counters; @@ -73,7 +73,7 @@ void CommandFileParser::ParseFile(std::string CommandFile, GatordMockService& mo // WAIT command // 11000000 timeout period in micro seconds - uint timeout = static_cast(std::stoul(tokens[1])); + uint32_t timeout = static_cast(std::stoul(tokens[1])); mockService.WaitCommand(timeout); } diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp index 529ef063dd..c5211962d3 100644 --- a/tests/profiling/gatordmock/GatordMockService.cpp +++ b/tests/profiling/gatordmock/GatordMockService.cpp @@ -8,17 +8,15 @@ #include #include #include +#include #include #include #include #include -#include #include -#include -#include -#include -#include + +using namespace armnnUtils; namespace armnn { @@ -28,6 +26,7 @@ namespace gatordmock bool GatordMockService::OpenListeningSocket(std::string udsNamespace) { + Sockets::Initialize(); m_ListeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); if (-1 == m_ListeningSocket) { @@ -56,9 +55,9 @@ bool GatordMockService::OpenListeningSocket(std::string udsNamespace) return true; } -int GatordMockService::BlockForOneClient() +Sockets::Socket GatordMockService::BlockForOneClient() { - m_ClientConnection = accept4(m_ListeningSocket, nullptr, nullptr, SOCK_CLOEXEC); + m_ClientConnection = Sockets::Accept(m_ListeningSocket, nullptr, nullptr, SOCK_CLOEXEC); if (-1 == m_ClientConnection) { std::cerr << ": Failure when waiting for a client connection: " << strerror(errno) << std::endl; @@ -112,13 +111,14 @@ bool GatordMockService::WaitForStreamMetaData() // Remember we already read the pipe magic 4 bytes. uint32_t metaDataLength = ToUint32(&header[4], m_Endianness) - 4; // Read the entire packet. - uint8_t packetData[metaDataLength]; - if (metaDataLength != boost::numeric_cast(read(m_ClientConnection, &packetData, metaDataLength))) + std::vector packetData(metaDataLength); + if (metaDataLength != + boost::numeric_cast(Sockets::Read(m_ClientConnection, packetData.data(), metaDataLength))) { std::cerr << ": Protocol read error. Data length mismatch." << std::endl; return false; } - EchoPacket(PacketDirection::ReceivedData, packetData, metaDataLength); + EchoPacket(PacketDirection::ReceivedData, packetData.data(), metaDataLength); m_StreamMetaDataVersion = ToUint32(&packetData[0], m_Endianness); m_StreamMetaDataMaxDataLen = ToUint32(&packetData[4], m_Endianness); m_StreamMetaDataPid = ToUint32(&packetData[8], m_Endianness); @@ -153,10 +153,9 @@ bool GatordMockService::LaunchReceivingThread() std::cout << "Launching receiving thread." << std::endl; } // At this point we want to make the socket non blocking. - const int currentFlags = fcntl(m_ClientConnection, F_GETFL); - if (0 != fcntl(m_ClientConnection, F_SETFL, currentFlags | O_NONBLOCK)) + if (!Sockets::SetNonBlocking(m_ClientConnection)) { - close(m_ClientConnection); + Sockets::Close(m_ClientConnection); std::cerr << "Failed to set socket as non blocking: " << strerror(errno) << std::endl; return false; } @@ -212,13 +211,13 @@ void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::v // should deal with it. } -void GatordMockService::WaitCommand(uint timeout) +void GatordMockService::WaitCommand(uint32_t timeout) { // Wait for a maximum of timeout microseconds or if the receive thread has closed. // There is a certain level of rounding involved in this timing. - uint iterations = timeout / 1000; + uint32_t iterations = timeout / 1000; std::cout << std::dec << "Wait command with timeout of " << timeout << " iterations = " << iterations << std::endl; - uint count = 0; + uint32_t count = 0; while ((this->ReceiveThreadRunning() && (count < iterations))) { std::this_thread::sleep_for(std::chrono::microseconds(1000)); @@ -261,7 +260,7 @@ armnn::profiling::Packet GatordMockService::WaitForPacket(uint32_t timeoutMs) { // Is there currently more than a headers worth of data waiting to be read? int bytes_available; - ioctl(m_ClientConnection, FIONREAD, &bytes_available); + Sockets::Ioctl(m_ClientConnection, FIONREAD, &bytes_available); if (bytes_available > 8) { // Yes there is. Read it: @@ -272,7 +271,7 @@ armnn::profiling::Packet GatordMockService::WaitForPacket(uint32_t timeoutMs) // No there's not. Poll for more data. struct pollfd pollingFd[1]{}; pollingFd[0].fd = m_ClientConnection; - int pollResult = poll(pollingFd, 1, static_cast(timeoutMs)); + int pollResult = Sockets::Poll(pollingFd, 1, static_cast(timeoutMs)); switch (pollResult) { @@ -362,16 +361,16 @@ bool GatordMockService::SendPacket(uint32_t packetFamily, uint32_t packetId, con header[0] = packetFamily << 26 | packetId << 16; header[1] = dataLength; // Add the header to the packet. - uint8_t packet[8 + dataLength]; - InsertU32(header[0], packet, m_Endianness); - InsertU32(header[1], packet + 4, m_Endianness); + std::vector packet(8 + dataLength); + InsertU32(header[0], packet.data(), m_Endianness); + InsertU32(header[1], packet.data() + 4, m_Endianness); // And the rest of the data if there is any. if (dataLength > 0) { - memcpy((packet + 8), data, dataLength); + memcpy((packet.data() + 8), data, dataLength); } - EchoPacket(PacketDirection::Sending, packet, sizeof(packet)); - if (-1 == write(m_ClientConnection, packet, sizeof(packet))) + EchoPacket(PacketDirection::Sending, packet.data(), packet.size()); + if (-1 == Sockets::Write(m_ClientConnection, packet.data(), packet.size())) { std::cerr << ": Failure when writing to client socket: " << strerror(errno) << std::endl; return false; @@ -396,10 +395,10 @@ bool GatordMockService::ReadHeader(uint32_t headerAsWords[2]) bool GatordMockService::ReadFromSocket(uint8_t* packetData, uint32_t expectedLength) { // This is a blocking read until either expectedLength has been received or an error is detected. - ssize_t totalBytesRead = 0; + long totalBytesRead = 0; while (boost::numeric_cast(totalBytesRead) < expectedLength) { - ssize_t bytesRead = recv(m_ClientConnection, packetData, expectedLength, 0); + long bytesRead = Sockets::Read(m_ClientConnection, packetData, expectedLength); if (bytesRead < 0) { std::cerr << ": Failure when reading from client socket: " << strerror(errno) << std::endl; diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp index c3afc333ca..f91e902db8 100644 --- a/tests/profiling/gatordmock/GatordMockService.hpp +++ b/tests/profiling/gatordmock/GatordMockService.hpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -49,8 +50,8 @@ public: ~GatordMockService() { // We have set SOCK_CLOEXEC on these sockets but we'll close them to be good citizens. - close(m_ClientConnection); - close(m_ListeningSocket); + armnnUtils::Sockets::Close(m_ClientConnection); + armnnUtils::Sockets::Close(m_ListeningSocket); } /// Establish the Unix domain socket and set it to listen for connections. @@ -60,7 +61,7 @@ public: /// Block waiting to accept one client to connect to the UDS. /// @return the file descriptor of the client connection. - int BlockForOneClient(); + armnnUtils::Sockets::Socket BlockForOneClient(); /// Once the connection is open wait to receive the stream meta data packet from the client. Reading this /// packet differs from others as we need to determine endianness. @@ -147,8 +148,8 @@ private: armnn::profiling::CommandHandlerRegistry& m_HandlerRegistry; bool m_EchoPackets; - int m_ListeningSocket; - int m_ClientConnection; + armnnUtils::Sockets::Socket m_ListeningSocket; + armnnUtils::Sockets::Socket m_ClientConnection; std::thread m_ListeningThread; std::atomic m_CloseReceivingThread; }; diff --git a/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp b/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp index bdceca69b0..78b1300ed3 100644 --- a/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp +++ b/tests/profiling/timelineDecoder/TimelineCaptureCommandHandler.cpp @@ -122,7 +122,7 @@ void TimelineCaptureCommandHandler::ReadEvent(const unsigned char* data, uint32_ event.m_TimeStamp = profiling::ReadUint64(data, offset); offset += uint64_t_size; - event.m_ThreadId = new u_int8_t[threadId_size]; + event.m_ThreadId = new uint8_t[threadId_size]; profiling::ReadBytes(data, offset, threadId_size, event.m_ThreadId); offset += threadId_size; -- cgit v1.2.1