ArmNN  NotReleased
SocketProfilingConnection.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 #include <cerrno>
9 #include <fcntl.h>
10 #include <string>
11 
12 using namespace armnnUtils;
13 
14 namespace armnn
15 {
16 namespace profiling
17 {
18 
// NOTE(review): the signature line for this body is absent from this listing (and listing line 21
// is missing too). The body creates a unix-domain stream socket, connects it to m_GatorNamespace and
// switches it to non-blocking mode -- presumably the SocketProfilingConnection constructor; confirm
// against the original file.
// Throws armnn::RuntimeException if the socket cannot be created, connected, or made non-blocking.
20 {
    // Start from an all-zero poll descriptor (fd, events and revents all become 0).
22  memset(m_Socket, 0, sizeof(m_Socket));
23  // Note: we're using Linux specific SOCK_CLOEXEC flag.
24  m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
25  if (m_Socket[0].fd == -1)
26  {
27  throw armnn::RuntimeException(std::string("Socket construction failed: ") + strerror(errno));
28  }
29 
30  // Connect to the named unix domain socket.
31  sockaddr_un server{};
32  memset(&server, 0, sizeof(sockaddr_un));
33  // As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
    // strlen() is taken from the second byte onwards and 1 is added back for the leading null, so the
    // leading null plus the name characters are copied; the name's own terminator is not copied
    // (server was zero-filled above).
34  memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
35  server.sun_family = AF_UNIX;
    // The address length passed to connect() is the full size of sockaddr_un.
36  if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un)))
37  {
    // NOTE(review): Close() runs before strerror(errno) is evaluated below, so errno may no longer
    // hold the value set by connect() -- consider capturing the message before closing.
38  Close();
39  throw armnn::RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno));
40  }
41 
42  // Our socket will only be interested in polling reads.
43  m_Socket[0].events = POLLIN;
44 
45  // Make the socket non blocking.
46  if (!Sockets::SetNonBlocking(m_Socket[0].fd))
47  {
    // NOTE(review): same ordering concern as above -- errno is read after Close() has run.
48  Close();
49  throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno));
50  }
51 }
52 
// NOTE(review): the signature line for this body is absent from this listing -- presumably the
// "is the connection open" query; confirm against the original file.
// Returns true when the stored descriptor is strictly positive. The close routine in this file
// zero-fills m_Socket, so an fd of 0 is what a closed connection looks like here.
54 {
55  return m_Socket[0].fd > 0;
56 }
57 
// NOTE(review): the signature line for this body is absent from this listing -- presumably
// SocketProfilingConnection::Close(); confirm against the original file.
// Closes the stored descriptor and then zero-fills the poll descriptor.
// Throws armnn::RuntimeException when Sockets::Close returns a non-zero value; in that case the
// memset below is not reached, so m_Socket keeps its previous contents.
59 {
60  if (Sockets::Close(m_Socket[0].fd) != 0)
61  {
62  throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno));
63  }
64 
    // Reset fd/events/revents to 0 so the descriptor no longer looks open.
65  memset(m_Socket, 0, sizeof(m_Socket));
66 }
67 
68 bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length)
69 {
70  if (buffer == nullptr || length == 0)
71  {
72  return false;
73  }
74 
75  return Sockets::Write(m_Socket[0].fd, buffer, length) != -1;
76 }
77 
// NOTE(review): the signature line for this body is absent from this listing -- presumably
// SocketProfilingConnection::ReadPacket(timeout); confirm against the original file.
// Returns the result of ReceivePacket() either immediately (when at least 8 bytes are already
// queued on the socket) or after a successful poll that reports POLLIN.
// Throws TimeoutException when the poll times out or wakes without POLLIN, and
// armnn::RuntimeException when the poll fails or reports POLLNVAL / POLLERR / POLLHUP on its own.
79 {
80  // Is there currently at least a header worth of data waiting to be read?
81  int bytes_available = 0;
    // The return value of Ioctl is not checked; bytes_available was initialised to 0 above.
82  Sockets::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
83  if (bytes_available >= 8)
84  {
85  // Yes there is. Read it:
86  return ReceivePacket();
87  }
88 
89  // Poll for data on the socket or until timeout occurs
90  int pollResult = Sockets::Poll(&m_Socket[0], 1, static_cast<int>(timeout));
91 
92  switch (pollResult)
93  {
94  case -1: // Error
95  throw armnn::RuntimeException(std::string("Read failure from socket: ") + strerror(errno));
96 
97  case 0: // Timeout
98  throw TimeoutException("Timeout while reading from socket");
99 
100  default: // Normal poll return but it could still contain an error signal
101  // Check if the socket reported an error
102  if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
103  {
    // Each test below uses ==, so it only matches when that flag is the sole bit set in revents.
    // Combinations (for example POLLIN together with POLLHUP) match none of them and fall through
    // to the POLLIN check further down.
104  if (m_Socket[0].revents == POLLNVAL)
105  {
106  // This is an unrecoverable error.
107  Close();
108  throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLNVAL"));
109  }
110  if (m_Socket[0].revents == POLLERR)
111  {
112  throw armnn::RuntimeException(std::string("Error while polling receiving socket: POLLERR: ") +
113  strerror(errno));
114  }
115  if (m_Socket[0].revents == POLLHUP)
116  {
117  // This is an unrecoverable error.
118  Close();
119  throw armnn::RuntimeException(std::string("Connection closed by remote client: POLLHUP"));
120  }
121  }
122 
123  // Check if there is data to read
124  if (!(m_Socket[0].revents & (POLLIN)))
125  {
126  // This is a corner case. The socket has been woken up but not with any data.
127  // We'll throw a timeout exception to loop around again.
128  throw armnn::TimeoutException("File descriptor was polled but no data was available to receive.");
129  }
130 
131  return ReceivePacket();
132  }
133 }
134 
135 Packet SocketProfilingConnection::ReceivePacket()
136 {
137  char header[8] = {};
138  long receiveResult = Sockets::Read(m_Socket[0].fd, &header, sizeof(header));
139  // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error.
140  switch( receiveResult )
141  {
142  case 0:
143  // Socket has closed.
144  Close();
145  throw armnn::RuntimeException("Remote socket has closed the connection.");
146  case -1:
147  // There's been a socket error. We will presume it's unrecoverable.
148  Close();
149  throw armnn::RuntimeException(std::string("Error occured on recv: ") + strerror(errno));
150  default:
151  if (receiveResult < 8)
152  {
153  throw armnn::RuntimeException("The received packet did not contains a valid MIPE header");
154  }
155  break;
156  }
157 
158  // stream_metadata_identifier is the first 4 bytes
159  uint32_t metadataIdentifier = 0;
160  std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier));
161 
162  // data_length is the next 4 bytes
163  uint32_t dataLength = 0;
164  std::memcpy(&dataLength, header + 4u, sizeof(dataLength));
165 
166  std::unique_ptr<unsigned char[]> packetData;
167  if (dataLength > 0)
168  {
169  packetData = std::make_unique<unsigned char[]>(dataLength);
170  long receivedLength = Sockets::Read(m_Socket[0].fd, packetData.get(), dataLength);
171  if (receivedLength < 0)
172  {
173  throw armnn::RuntimeException(std::string("Error occurred on recv: ") + strerror(errno));
174  }
175  if (dataLength != static_cast<uint32_t>(receivedLength))
176  {
177  // What do we do here if we can't read in a full packet?
178  throw armnn::RuntimeException("Invalid MIPE packet");
179  }
180  }
181 
182  return Packet(metadataIdentifier, dataLength, packetData);
183 }
184 
185 } // namespace profiling
186 } // namespace armnn
long Write(Socket s, const void *buf, size_t len)
bool WritePacket(const unsigned char *buffer, uint32_t length) final
bool Initialize()
Performs any required one-time setup.
long Read(Socket s, void *buf, size_t len)
int Poll(PollFd *fds, nfds_t numFds, int timeout)
int Ioctl(Socket s, unsigned long int cmd, void *arg)
bool SetNonBlocking(Socket s)