aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/armnnUtils/NetworkSockets.cpp99
-rw-r--r--src/armnnUtils/NetworkSockets.hpp59
-rw-r--r--src/profiling/SocketProfilingConnection.cpp25
-rw-r--r--src/profiling/SocketProfilingConnection.hpp4
4 files changed, 172 insertions, 15 deletions
diff --git a/src/armnnUtils/NetworkSockets.cpp b/src/armnnUtils/NetworkSockets.cpp
new file mode 100644
index 0000000000..cc28a90c48
--- /dev/null
+++ b/src/armnnUtils/NetworkSockets.cpp
@@ -0,0 +1,99 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "NetworkSockets.hpp"
+
+#if defined(__unix__)
+#include <unistd.h>
+#include <fcntl.h>
+#endif
+
+namespace armnnUtils
+{
+namespace Sockets
+{
+
+bool Initialize()
+{
+#if defined(__unix__)
+ return true;
+#elif defined(_MSC_VER)
+ WSADATA wsaData;
+ return WSAStartup(MAKEWORD(2, 2), &wsaData) == 0;
+#endif
+}
+
+int Close(Socket s)
+{
+#if defined(__unix__)
+ return close(s);
+#elif defined(_MSC_VER)
+ return closesocket(s);
+#endif
+}
+
+
+bool SetNonBlocking(Socket s)
+{
+#if defined(__unix__)
+ const int currentFlags = fcntl(s, F_GETFL);
+ return fcntl(s, F_SETFL, currentFlags | O_NONBLOCK) == 0;
+#elif defined(_MSC_VER)
+ u_long mode = 1;
+ return ioctlsocket(s, FIONBIO, &mode) == 0;
+#endif
+}
+
+
+long Write(Socket s, const void* buf, size_t len)
+{
+#if defined(__unix__)
+ return write(s, buf, len);
+#elif defined(_MSC_VER)
+ return send(s, static_cast<const char*>(buf), len, 0);
+#endif
+}
+
+
+long Read(Socket s, void* buf, size_t len)
+{
+#if defined(__unix__)
+ return read(s, buf, len);
+#elif defined(_MSC_VER)
+ return recv(s, static_cast<char*>(buf), len, 0);
+#endif
+}
+
+int Ioctl(Socket s, unsigned long cmd, void* arg)
+{
+#if defined(__unix__)
+ return ioctl(s, cmd, arg);
+#elif defined(_MSC_VER)
+ return ioctlsocket(s, cmd, static_cast<u_long*>(arg));
+#endif
+}
+
+
+int Poll(PollFd* fds, size_t numFds, int timeout)
+{
+#if defined(__unix__)
+ return poll(fds, numFds, timeout);
+#elif defined(_MSC_VER)
+ return WSAPoll(fds, numFds, timeout);
+#endif
+}
+
+
+armnnUtils::Sockets::Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags)
+{
+#if defined(__unix__)
+ return accept4(s, addr, addrlen, flags);
+#elif defined(_MSC_VER)
+ return accept(s, addr, reinterpret_cast<int*>(addrlen));
+#endif
+}
+
+}
+}
diff --git a/src/armnnUtils/NetworkSockets.hpp b/src/armnnUtils/NetworkSockets.hpp
new file mode 100644
index 0000000000..9e4770793c
--- /dev/null
+++ b/src/armnnUtils/NetworkSockets.hpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+// This file (along with its corresponding .cpp) defines a very thin platform abstraction layer for the use of
+// networking sockets. Thankfully the underlying APIs on Windows and Linux are very similar so not much conversion
+// is needed (typically just forwarding the parameters to a differently named function).
+// Some of the APIs are in fact completely identical and so no forwarding function is needed.
+
+#pragma once
+
+#if defined(__unix__)
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#elif defined(_MSC_VER)
+#include <winsock2.h>
+#include <afunix.h>
+#endif
+
+namespace armnnUtils
+{
+namespace Sockets
+{
+
+#if defined(__unix__)
+
+using Socket = int;
+using PollFd = pollfd;
+
+#elif defined(_MSC_VER)
+
+using Socket = SOCKET;
+using PollFd = WSAPOLLFD;
+#define SOCK_CLOEXEC 0
+
+#endif
+
+/// Performs any required one-time setup.
+bool Initialize();
+
+int Close(Socket s);
+
+bool SetNonBlocking(Socket s);
+
+long Write(Socket s, const void* buf, size_t len);
+
+long Read(Socket s, void* buf, size_t len);
+
+int Ioctl(Socket s, unsigned long cmd, void* arg);
+
+int Poll(PollFd* fds, size_t numFds, int timeout);
+
+Socket Accept(Socket s, sockaddr* addr, unsigned int* addrlen, int flags);
+
+}
+}
diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp
index c78c182412..4bbbc2962c 100644
--- a/src/profiling/SocketProfilingConnection.cpp
+++ b/src/profiling/SocketProfilingConnection.cpp
@@ -7,11 +7,10 @@
#include <cerrno>
#include <fcntl.h>
-#include <sys/ioctl.h>
-#include <sys/socket.h>
-#include <sys/un.h>
#include <string>
+using namespace armnnUtils;
+
namespace armnn
{
namespace profiling
@@ -19,6 +18,7 @@ namespace profiling
SocketProfilingConnection::SocketProfilingConnection()
{
+ Sockets::Initialize();
memset(m_Socket, 0, sizeof(m_Socket));
// Note: we're using Linux specific SOCK_CLOEXEC flag.
m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
@@ -28,7 +28,7 @@ SocketProfilingConnection::SocketProfilingConnection()
}
// Connect to the named unix domain socket.
- struct sockaddr_un server{};
+ sockaddr_un server{};
memset(&server, 0, sizeof(sockaddr_un));
// As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
@@ -43,8 +43,7 @@ SocketProfilingConnection::SocketProfilingConnection()
m_Socket[0].events = POLLIN;
// Make the socket non blocking.
- const int currentFlags = fcntl(m_Socket[0].fd, F_GETFL);
- if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK))
+ if (!Sockets::SetNonBlocking(m_Socket[0].fd))
{
Close();
throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno));
@@ -58,7 +57,7 @@ bool SocketProfilingConnection::IsOpen() const
void SocketProfilingConnection::Close()
{
- if (close(m_Socket[0].fd) != 0)
+ if (Sockets::Close(m_Socket[0].fd) != 0)
{
throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno));
}
@@ -73,14 +72,14 @@ bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_
return false;
}
- return write(m_Socket[0].fd, buffer, length) != -1;
+ return Sockets::Write(m_Socket[0].fd, buffer, length) != -1;
}
Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
{
// Is there currently at least a header worth of data waiting to be read?
int bytes_available = 0;
- ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
+ Sockets::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
if (bytes_available >= 8)
{
// Yes there is. Read it:
@@ -88,7 +87,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
}
// Poll for data on the socket or until timeout occurs
- int pollResult = poll(m_Socket, 1, static_cast<int>(timeout));
+ int pollResult = Sockets::Poll(&m_Socket[0], 1, static_cast<int>(timeout));
switch (pollResult)
{
@@ -136,7 +135,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
Packet SocketProfilingConnection::ReceivePacket()
{
char header[8] = {};
- ssize_t receiveResult = recv(m_Socket[0].fd, &header, sizeof(header), 0);
+ long receiveResult = Sockets::Read(m_Socket[0].fd, &header, sizeof(header));
// We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error.
switch( receiveResult )
{
@@ -168,10 +167,10 @@ Packet SocketProfilingConnection::ReceivePacket()
if (dataLength > 0)
{
packetData = std::make_unique<unsigned char[]>(dataLength);
- ssize_t receivedLength = recv(m_Socket[0].fd, packetData.get(), dataLength, 0);
+ long receivedLength = Sockets::Read(m_Socket[0].fd, packetData.get(), dataLength);
if (receivedLength < 0)
{
- throw armnn::RuntimeException(std::string("Error occured on recv: ") + strerror(errno));
+ throw armnn::RuntimeException(std::string("Error occurred on recv: ") + strerror(errno));
}
if (dataLength != static_cast<uint32_t>(receivedLength))
{
diff --git a/src/profiling/SocketProfilingConnection.hpp b/src/profiling/SocketProfilingConnection.hpp
index 5fb02bb19e..05c7130de7 100644
--- a/src/profiling/SocketProfilingConnection.hpp
+++ b/src/profiling/SocketProfilingConnection.hpp
@@ -5,8 +5,8 @@
#include "IProfilingConnection.hpp"
-#include <poll.h>
#include <Runtime.hpp>
+#include <NetworkSockets.hpp>
#pragma once
@@ -31,7 +31,7 @@ private:
// To indicate we want to use an abstract UDS ensure the first character of the address is 0.
const char* m_GatorNamespace = "\0gatord_namespace";
- struct pollfd m_Socket[1]{};
+ armnnUtils::Sockets::PollFd m_Socket[1]{};
};
} // namespace profiling