diff options
Diffstat (limited to 'src/profiling/SocketProfilingConnection.cpp')
-rw-r--r-- | src/profiling/SocketProfilingConnection.cpp | 120 |
1 files changed, 58 insertions, 62 deletions
diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index 91d57cc9bd..47fc62f7f0 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -23,7 +23,7 @@ SocketProfilingConnection::SocketProfilingConnection() m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); if (m_Socket[0].fd == -1) { - throw armnn::Exception(std::string(": Socket construction failed: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Socket construction failed: ") + strerror(errno)); } // Connect to the named unix domain socket. @@ -35,7 +35,7 @@ SocketProfilingConnection::SocketProfilingConnection() if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un))) { close(m_Socket[0].fd); - throw armnn::Exception(std::string(": Cannot connect to stream socket: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Cannot connect to stream socket: ") + strerror(errno)); } // Our socket will only be interested in polling reads. @@ -46,96 +46,92 @@ SocketProfilingConnection::SocketProfilingConnection() if (0 != fcntl(m_Socket[0].fd, F_SETFL, currentFlags | O_NONBLOCK)) { close(m_Socket[0].fd); - throw armnn::Exception(std::string(": Failed to set socket as non blocking: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Failed to set socket as non blocking: ") + strerror(errno)); } } bool SocketProfilingConnection::IsOpen() { - if (m_Socket[0].fd > 0) - { - return true; - } - return false; + return m_Socket[0].fd > 0; } void SocketProfilingConnection::Close() { - if (0 == close(m_Socket[0].fd)) - { - memset(m_Socket, 0, sizeof(m_Socket)); - } - else + if (close(m_Socket[0].fd) != 0) { - throw armnn::Exception(std::string(": Cannot close stream socket: ") + strerror(errno)); + throw armnn::RuntimeException(std::string("Cannot close stream socket: ") + strerror(errno)); } + + memset(m_Socket, 0, sizeof(m_Socket)); } -bool SocketProfilingConnection::WritePacket(const char* buffer, uint32_t length) +bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length) { - if (-1 == write(m_Socket[0].fd, buffer, length)) + if (buffer == nullptr || length == 0) { return false; } - return true; + + return write(m_Socket[0].fd, buffer, length) != -1; } Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) { - // Poll for data on the socket or until timeout. + // Poll for data on the socket or until timeout occurs int pollResult = poll(m_Socket, 1, static_cast<int>(timeout)); - if (pollResult > 0) + + switch (pollResult) { - // Normal poll return but it could still contain an error signal. + case -1: // Error + throw armnn::RuntimeException(std::string("Read failure from socket: ") + strerror(errno)); + + case 0: // Timeout + throw armnn::RuntimeException("Timeout while reading from socket"); + + default: // Normal poll return but it could still contain an error signal + + // Check if the socket reported an error if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP)) { - throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno)); + throw armnn::Exception(std::string("Socket 0 reported an error: ") + strerror(errno)); + } + + // Check if there is data to read + if (!(m_Socket[0].revents & (POLLIN))) + { + // No data to read from the socket. Silently ignore and continue + return Packet(); } - else if (m_Socket[0].revents & (POLLIN)) // There is data to read. + + // There is data to read, read the header first + char header[8] = {}; + if (8 != recv(m_Socket[0].fd, &header, sizeof(header), 0)) { - // Read the header first. - char header[8]; - if (8 != recv(m_Socket[0].fd, &header, sizeof header, 0)) - { - // What do we do here if there's not a valid 8 byte header to read? - throw armnn::Exception(": Received packet did not contains a valid MIPE header. "); - } - // stream_metadata_identifier is the first 4 bytes. - uint32_t metadataIdentifier = static_cast<uint32_t>(header[0]) << 24 | - static_cast<uint32_t>(header[1]) << 16 | - static_cast<uint32_t>(header[2]) << 8 | - static_cast<uint32_t>(header[3]); - // data_length is the next 4 bytes. - uint32_t dataLength = static_cast<uint32_t>(header[4]) << 24 | - static_cast<uint32_t>(header[5]) << 16 | - static_cast<uint32_t>(header[6]) << 8 | - static_cast<uint32_t>(header[7]); - - std::unique_ptr<char[]> packetData; - if (dataLength > 0) - { - packetData = std::make_unique<char[]>(dataLength); - } - - if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0)) - { - // What do we do here if we can't read in a full packet? - throw armnn::Exception(": Invalid MIPE packet."); - } - return {metadataIdentifier, dataLength, packetData}; + // What do we do here if there's not a valid 8 byte header to read? + throw armnn::RuntimeException("The received packet did not contains a valid MIPE header"); } - else // Some unknown return signal. + + // stream_metadata_identifier is the first 4 bytes + uint32_t metadataIdentifier = 0; + std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier)); + + // data_length is the next 4 bytes + uint32_t dataLength = 0; + std::memcpy(&dataLength, header + 4u, sizeof(dataLength)); + + std::unique_ptr<char[]> packetData; + if (dataLength > 0) { - throw armnn::Exception(": Poll returned an unexpected event." ); + packetData = std::make_unique<char[]>(dataLength); } - } - else if (pollResult == -1) - { - throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno)); - } - else // it's 0 so a timeout. - { - throw armnn::TimeoutException(": Timeout while reading from socket."); + + if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0)) + { + // What do we do here if we can't read in a full packet? + throw armnn::RuntimeException("Invalid MIPE packet"); + } + + return Packet(metadataIdentifier, dataLength, packetData); } } |