diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnnUtils/NetworkSockets.cpp | 99 | ||||
-rw-r--r-- | src/armnnUtils/NetworkSockets.hpp | 59 | ||||
-rw-r--r-- | src/profiling/SocketProfilingConnection.cpp | 25 | ||||
-rw-r--r-- | src/profiling/SocketProfilingConnection.hpp | 4 |
4 files changed, 172 insertions, 15 deletions
diff --git a/src/armnnUtils/NetworkSockets.cpp b/src/armnnUtils/NetworkSockets.cpp new file mode 100644 index 0000000000..cc28a90c48 --- /dev/null +++ b/src/armnnUtils/NetworkSockets.cpp @@ -0,0 +1,99 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "NetworkSockets.hpp" + +#if defined(__unix__) +#include <unistd.h> +#include <fcntl.h> +#endif + +namespace armnnUtils +{ +namespace Sockets +{ + +bool Initialize() +{ +#if defined(__unix__) + return true; +#elif defined(_MSC_VER) + WSADATA wsaData; + return WSAStartup(MAKEWORD(2, 2), &wsaData) == 0; +#endif +} + +int Close(Socket s) +{ +#if defined(__unix__) + return close(s); +#elif defined(_MSC_VER) + return closesocket(s); +#endif +} + + +bool SetNonBlocking(Socket s) +{ +#if defined(__unix__) + const int currentFlags = fcntl(s, F_GETFL); + return fcntl(s, F_SETFL, currentFlags | O_NONBLOCK) == 0; +#elif defined(_MSC_VER) + u_long mode = 1; + return ioctlsocket(s, FIONBIO, &mode) == 0; +#endif +} + + +long Write(Socket s, const void* buf, size_t len) +{ +#if defined(__unix__) + return write(s, buf, len); +#elif defined(_MSC_VER) + return send(s, static_cast<const char*>(buf), len, 0); +#endif +} + + +long Read(Socket s, void* buf, size_t len) +{ +#if defined(__unix__) + return read(s, buf, len); +#elif defined(_MSC_VER) + return recv(s, static_cast<char*>(buf), len, 0); +#endif +} + +int Ioctl(Socket s, unsigned long cmd, void* arg) +{ +#if defined(__unix__) + return ioctl(s, cmd, arg); +#elif defined(_MSC_VER) + return ioctlsocket(s, cmd, static_cast<u_long*>(arg)); +#endif +} + + +int Poll(PollFd* fds, size_t numFds, int timeout) +{ +#if defined(__unix__) + return poll(fds, numFds, timeout); +#elif defined(_MSC_VER) + return WSAPoll(fds, numFds, timeout); +#endif +} + + +armnnUtils::Sockets::Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags) +{ +#if defined(__unix__) + return accept4(s, addr, addrlen, flags); +#elif defined(_MSC_VER) + return accept(s, addr, reinterpret_cast<int*>(addrlen)); +#endif +} + +} +} diff --git a/src/armnnUtils/NetworkSockets.hpp b/src/armnnUtils/NetworkSockets.hpp new file mode 100644 index 0000000000..9e4770793c --- /dev/null +++ b/src/armnnUtils/NetworkSockets.hpp @@ -0,0 +1,59 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +// This file (along with its corresponding .cpp) defines a very thin platform abstraction layer for the use of +// networking sockets. Thankfully the underlying APIs on Windows and Linux are very similar so not much conversion +// is needed (typically just forwarding the parameters to a differently named function). +// Some of the APIs are in fact completely identical and so no forwarding function is needed. + +#pragma once + +#if defined(__unix__) +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/un.h> +#elif defined(_MSC_VER) +#include <winsock2.h> +#include <afunix.h> +#endif + +namespace armnnUtils +{ +namespace Sockets +{ + +#if defined(__unix__) + +using Socket = int; +using PollFd = pollfd; + +#elif defined(_MSC_VER) + +using Socket = SOCKET; +using PollFd = WSAPOLLFD; +#define SOCK_CLOEXEC 0 + +#endif + +/// Performs any required one-time setup. +bool Initialize(); + +int Close(Socket s); + +bool SetNonBlocking(Socket s); + +long Write(Socket s, const void* buf, size_t len); + +long Read(Socket s, void* buf, size_t len); + +int Ioctl(Socket s, unsigned long cmd, void* arg); + +int Poll(PollFd* fds, size_t numFds, int timeout); + +Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags); + +} +} diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index c78c182412..4bbbc2962c 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -7,11 +7,10 @@ #include <cerrno> #include <fcntl.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/un.h> #include <string> +using namespace armnnUtils; + namespace armnn { namespace profiling @@ -19,6 +18,7 @@ namespace profiling SocketProfilingConnection::SocketProfilingConnection() { + Sockets::Initialize(); memset(m_Socket, 0, sizeof(m_Socket)); // Note: we're using Linux specific SOCK_CLOEXEC flag. m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); @@ -28,7 +28,7 @@ SocketProfilingConnection::SocketProfilingConnection() } // Connect to the named unix domain socket. - struct sockaddr_un server{}; + sockaddr_un server{}; memset(&server, 0, sizeof(sockaddr_un)); // As m_GatorNamespace begins with a null character we need to ignore that when getting its length. memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1); @@ -43,8 +43,7 @@ SocketProfilingConnection::SocketProfilingConnection() m_Socket[0].events = POLLIN; // Make the socket non blocking. - const int currentFlags = fcntl(m_Socket[0].fd, F_GETFL); - if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK)) + if (!Sockets::SetNonBlocking(m_Socket[0].fd)) { Close(); throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno)); @@ -58,7 +57,7 @@ bool SocketProfilingConnection::IsOpen() const void SocketProfilingConnection::Close() { - if (close(m_Socket[0].fd) != 0) + if (Sockets::Close(m_Socket[0].fd) != 0) { throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno)); } @@ -73,14 +72,14 @@ bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_ return false; } - return write(m_Socket[0].fd, buffer, length) != -1; + return Sockets::Write(m_Socket[0].fd, buffer, length) != -1; } Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) { // Is there currently at least a header worth of data waiting to be read? int bytes_available = 0; - ioctl(m_Socket[0].fd, FIONREAD, &bytes_available); + Sockets::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available); if (bytes_available >= 8) { // Yes there is. Read it: @@ -88,7 +87,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) } // Poll for data on the socket or until timeout occurs - int pollResult = poll(m_Socket, 1, static_cast<int>(timeout)); + int pollResult = Sockets::Poll(&m_Socket[0], 1, static_cast<int>(timeout)); switch (pollResult) { @@ -136,7 +135,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) Packet SocketProfilingConnection::ReceivePacket() { char header[8] = {}; - ssize_t receiveResult = recv(m_Socket[0].fd, &header, sizeof(header), 0); + long receiveResult = Sockets::Read(m_Socket[0].fd, &header, sizeof(header)); // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error. switch( receiveResult ) { @@ -168,10 +167,10 @@ Packet SocketProfilingConnection::ReceivePacket() if (dataLength > 0) { packetData = std::make_unique<unsigned char[]>(dataLength); - ssize_t receivedLength = recv(m_Socket[0].fd, packetData.get(), dataLength, 0); + long receivedLength = Sockets::Read(m_Socket[0].fd, packetData.get(), dataLength); if (receivedLength < 0) { - throw armnn::RuntimeException(std::string("Error occured on recv: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Error occurred on recv: ") + strerror(errno)); } if (dataLength != static_cast<uint32_t>(receivedLength)) { diff --git a/src/profiling/SocketProfilingConnection.hpp b/src/profiling/SocketProfilingConnection.hpp index 5fb02bb19e..05c7130de7 100644 --- a/src/profiling/SocketProfilingConnection.hpp +++ b/src/profiling/SocketProfilingConnection.hpp @@ -5,8 +5,8 @@ #include "IProfilingConnection.hpp" -#include <poll.h> #include <Runtime.hpp> +#include <NetworkSockets.hpp> #pragma once @@ -31,7 +31,7 @@ private: // To indicate we want to use an abstract UDS ensure the first character of the address is 0. const char* m_GatorNamespace = "\0gatord_namespace"; - struct pollfd m_Socket[1]{}; + armnnUtils::Sockets::PollFd m_Socket[1]{}; }; } // namespace profiling |