aboutsummaryrefslogtreecommitdiff
path: root/src/profiling/SocketProfilingConnection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/profiling/SocketProfilingConnection.cpp')
-rw-r--r--src/profiling/SocketProfilingConnection.cpp78
1 files changed, 72 insertions, 6 deletions
diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp
index 21a7a1d53c..188ca23e12 100644
--- a/src/profiling/SocketProfilingConnection.cpp
+++ b/src/profiling/SocketProfilingConnection.cpp
@@ -52,25 +52,91 @@ SocketProfilingConnection::SocketProfilingConnection()
bool SocketProfilingConnection::IsOpen()
{
- // Dummy return value, function not implemented
- return true;
+ if (m_Socket[0].fd > 0)
+ {
+ return true;
+ }
+ return false;
}
void SocketProfilingConnection::Close()
{
- // Function not implemented
+ if (0 == close(m_Socket[0].fd))
+ {
+ memset(m_Socket, 0, sizeof(m_Socket));
+ }
+ else
+ {
+ throw armnn::Exception(std::string(": Cannot close stream socket: ") + strerror(errno));
+ }
}
bool SocketProfilingConnection::WritePacket(const char* buffer, uint32_t length)
{
- // Dummy return value, function not implemented
+ if (-1 == write(m_Socket[0].fd, buffer, length))
+ {
+ return false;
+ }
return true;
}
Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
{
- // Dummy return value, function not implemented
- return {472580096, 0, nullptr};
+ // Poll for data on the socket or until timeout.
+ int pollResult = poll(m_Socket, 1, static_cast<int>(timeout));
+ if (pollResult > 0)
+ {
+ // Normal poll return but it could still contain an error signal.
+ if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
+ {
+ throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno));
+ }
+ else if (m_Socket[0].revents & (POLLIN)) // There is data to read.
+ {
+ // Read the header first.
+ char header[8];
+ if (8 != recv(m_Socket[0].fd, &header, sizeof header, 0))
+ {
+ // What do we do here if there's not a valid 8 byte header to read?
+ throw armnn::Exception(": Received packet did not contains a valid MIPE header. ");
+ }
+ // stream_metadata_identifier is the first 4 bytes.
+ uint32_t metadataIdentifier = static_cast<uint32_t>(header[0]) << 24 |
+ static_cast<uint32_t>(header[1]) << 16 |
+ static_cast<uint32_t>(header[2]) << 8 |
+ static_cast<uint32_t>(header[3]);
+ // data_length is the next 4 bytes.
+ uint32_t dataLength = static_cast<uint32_t>(header[4]) << 24 |
+ static_cast<uint32_t>(header[5]) << 16 |
+ static_cast<uint32_t>(header[6]) << 8 |
+ static_cast<uint32_t>(header[7]);
+
+ std::unique_ptr<char[]> packetData;
+ if (dataLength > 0)
+ {
+ packetData = std::make_unique<char[]>(dataLength);
+ }
+
+ if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0))
+ {
+ // What do we do here if we can't read in a full packet?
+ throw armnn::Exception(": Invalid MIPE packet.");
+ }
+ return {metadataIdentifier, dataLength, packetData};
+ }
+ else // Some unknown return signal.
+ {
+ throw armnn::Exception(": Poll returned an unexpected event." );
+ }
+ }
+ else if (pollResult == -1)
+ {
+ throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno));
+ }
+ else // it's 0 so a timeout.
+ {
+ throw armnn::Exception(": Timeout while reading from socket.");
+ }
}
} // namespace profiling