From a0c7871cf140d1e9cf59a213626ee534c0122c7f Mon Sep 17 00:00:00 2001 From: FinnWilliamsArm Date: Mon, 16 Sep 2019 12:06:47 +0100 Subject: IVGCVSW-3826: Implement IProfiling functions !armnn:1814 Signed-off-by: Teresa Charlin Signed-off-by: FinnWilliamsArm Change-Id: I82c7453d7969880e321572637adc0fb9c0e5fd7b --- src/profiling/Packet.cpp | 4 +- src/profiling/Packet.hpp | 20 ++++--- src/profiling/SocketProfilingConnection.cpp | 78 ++++++++++++++++++++++++--- src/profiling/test/ProfilingTests.cpp | 81 ++++++++++++++++++++--------- 4 files changed, 143 insertions(+), 40 deletions(-) diff --git a/src/profiling/Packet.cpp b/src/profiling/Packet.cpp index 44d5ac19e9..4cfa42bbc9 100644 --- a/src/profiling/Packet.cpp +++ b/src/profiling/Packet.cpp @@ -31,9 +31,9 @@ std::uint32_t Packet::GetLength() const return m_Length; } -const char* Packet::GetData() const +const char* const Packet::GetData() const { - return m_Data; + return m_Data.get(); } std::uint32_t Packet::GetPacketClass() const diff --git a/src/profiling/Packet.hpp b/src/profiling/Packet.hpp index c5e7f3c029..1e047a6511 100644 --- a/src/profiling/Packet.hpp +++ b/src/profiling/Packet.hpp @@ -17,10 +17,8 @@ namespace profiling class Packet { public: - Packet(uint32_t header, uint32_t length, const char* data) - : m_Header(header), - m_Length(length), - m_Data(data) + Packet(uint32_t header, uint32_t length, std::unique_ptr& data) + : m_Header(header), m_Length(length), m_Data(std::move(data)) { m_PacketId = ((header >> 16) & 1023); m_PacketFamily = (header >> 26); @@ -31,11 +29,21 @@ public: } } + Packet(Packet&& other) : + m_Header(other.m_Header), + m_PacketFamily(other.m_PacketFamily), + m_PacketId(other.m_PacketId), + m_Length(other.m_Length), + m_Data(std::move(other.m_Data)){}; + + Packet(const Packet& other) = delete; + Packet& operator=(const Packet&) = delete; + uint32_t GetHeader() const; uint32_t GetPacketFamily() const; uint32_t GetPacketId() const; uint32_t GetLength() const; - const char* GetData() const; + const char* const GetData() const; uint32_t GetPacketClass() const; uint32_t GetPacketType() const; @@ -45,7 +53,7 @@ private: uint32_t m_PacketFamily; uint32_t m_PacketId; uint32_t m_Length; - const char* m_Data; + std::unique_ptr m_Data; }; } // namespace profiling diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp index 21a7a1d53c..188ca23e12 100644 --- a/src/profiling/SocketProfilingConnection.cpp +++ b/src/profiling/SocketProfilingConnection.cpp @@ -52,25 +52,91 @@ SocketProfilingConnection::SocketProfilingConnection() bool SocketProfilingConnection::IsOpen() { - // Dummy return value, function not implemented - return true; + if (m_Socket[0].fd > 0) + { + return true; + } + return false; } void SocketProfilingConnection::Close() { - // Function not implemented + if (0 == close(m_Socket[0].fd)) + { + memset(m_Socket, 0, sizeof(m_Socket)); + } + else + { + throw armnn::Exception(std::string(": Cannot close stream socket: ") + strerror(errno)); + } } bool SocketProfilingConnection::WritePacket(const char* buffer, uint32_t length) { - // Dummy return value, function not implemented + if (-1 == write(m_Socket[0].fd, buffer, length)) + { + return false; + } return true; } Packet SocketProfilingConnection::ReadPacket(uint32_t timeout) { - // Dummy return value, function not implemented - return {472580096, 0, nullptr}; + // Poll for data on the socket or until timeout. + int pollResult = poll(m_Socket, 1, static_cast(timeout)); + if (pollResult > 0) + { + // Normal poll return but it could still contain an error signal. + if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP)) + { + throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno)); + } + else if (m_Socket[0].revents & (POLLIN)) // There is data to read. + { + // Read the header first. + char header[8]; + if (8 != recv(m_Socket[0].fd, &header, sizeof header, 0)) + { + // What do we do here if there's not a valid 8 byte header to read? + throw armnn::Exception(": Received packet did not contains a valid MIPE header. "); + } + // stream_metadata_identifier is the first 4 bytes. + uint32_t metadataIdentifier = static_cast(header[0]) << 24 | + static_cast(header[1]) << 16 | + static_cast(header[2]) << 8 | + static_cast(header[3]); + // data_length is the next 4 bytes. + uint32_t dataLength = static_cast(header[4]) << 24 | + static_cast(header[5]) << 16 | + static_cast(header[6]) << 8 | + static_cast(header[7]); + + std::unique_ptr packetData; + if (dataLength > 0) + { + packetData = std::make_unique(dataLength); + } + + if (dataLength != recv(m_Socket[0].fd, packetData.get(), dataLength, 0)) + { + // What do we do here if we can't read in a full packet? + throw armnn::Exception(": Invalid MIPE packet."); + } + return {metadataIdentifier, dataLength, packetData}; + } + else // Some unknown return signal. + { + throw armnn::Exception(": Poll returned an unexpected event." ); + } + } + else if (pollResult == -1) + { + throw armnn::Exception(std::string(": Read failure from socket: ") + strerror(errno)); + } + else // it's 0 so a timeout. + { + throw armnn::Exception(": Timeout while reading from socket."); + } } } // namespace profiling diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 4913dde6ee..55524a4dfe 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -114,21 +114,39 @@ BOOST_AUTO_TEST_CASE(CheckEncodeVersion) BOOST_AUTO_TEST_CASE(CheckPacketClass) { - const char* data = "test"; - unsigned int length = static_cast(std::strlen(data)); - - Packet packetTest1(472580096,length,data); - BOOST_CHECK_THROW(Packet packetTest2(472580096,0,""), armnn::Exception); - - Packet packetTest3(472580096,0, nullptr); - - BOOST_CHECK(packetTest1.GetLength() == length); - BOOST_CHECK(packetTest1.GetData() == data); - - BOOST_CHECK(packetTest1.GetPacketFamily() == 7); - BOOST_CHECK(packetTest1.GetPacketId() == 43); - BOOST_CHECK(packetTest1.GetPacketType() == 3); - BOOST_CHECK(packetTest1.GetPacketClass() == 5); + uint32_t length = 4; + std::unique_ptr packetData0 = std::make_unique(length); + std::unique_ptr packetData1 = std::make_unique(0); + std::unique_ptr nullPacketData; + + Packet packetTest0(472580096, length, packetData0); + + BOOST_CHECK(packetTest0.GetHeader() == 472580096); + BOOST_CHECK(packetTest0.GetPacketFamily() == 7); + BOOST_CHECK(packetTest0.GetPacketId() == 43); + BOOST_CHECK(packetTest0.GetLength() == length); + BOOST_CHECK(packetTest0.GetPacketType() == 3); + BOOST_CHECK(packetTest0.GetPacketClass() == 5); + + BOOST_CHECK_THROW(Packet packetTest1(472580096, 0, packetData1), armnn::Exception); + BOOST_CHECK_NO_THROW(Packet packetTest2(472580096, 0, nullPacketData)); + + Packet packetTest3(472580096, 0, nullPacketData); + BOOST_CHECK(packetTest3.GetLength() == 0); + BOOST_CHECK(packetTest3.GetData() == nullptr); + + const char* packetTest0Data = packetTest0.GetData(); + Packet packetTest4(std::move(packetTest0)); + + BOOST_CHECK(packetTest0.GetData() == nullptr); + BOOST_CHECK(packetTest4.GetData() == packetTest0Data); + + BOOST_CHECK(packetTest4.GetHeader() == 472580096); + BOOST_CHECK(packetTest4.GetPacketFamily() == 7); + BOOST_CHECK(packetTest4.GetPacketId() == 43); + BOOST_CHECK(packetTest4.GetLength() == length); + BOOST_CHECK(packetTest4.GetPacketType() == 3); + BOOST_CHECK(packetTest4.GetPacketClass() == 5); } // Create Derived Classes @@ -186,9 +204,13 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor) it++; BOOST_CHECK(it->first==keyC); - Packet packetA(500000000, 0, nullptr); - Packet packetB(600000000, 0, nullptr); - Packet packetC(400000000, 0, nullptr); + std::unique_ptr packetDataA; + std::unique_ptr packetDataB; + std::unique_ptr packetDataC; + + Packet packetA(500000000, 0, packetDataA); + Packet packetB(600000000, 0, packetDataB); + Packet packetC(400000000, 0, packetDataC); // Check the correct operator of derived class is called registry.at(CommandHandlerKey(packetA.GetPacketId(), version))->operator()(packetA); @@ -224,9 +246,13 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerRegistry) registry.RegisterFunctor(&testFunctorB, testFunctorB.GetPacketId(), testFunctorB.GetVersion()); registry.RegisterFunctor(&testFunctorC, testFunctorC.GetPacketId(), testFunctorC.GetVersion()); - Packet packetA(500000000, 0, nullptr); - Packet packetB(600000000, 0, nullptr); - Packet packetC(400000000, 0, nullptr); + std::unique_ptr packetDataA; + std::unique_ptr packetDataB; + std::unique_ptr packetDataC; + + Packet packetA(500000000, 0, packetDataA); + Packet packetB(600000000, 0, packetDataB); + Packet packetC(400000000, 0, packetDataC); // Check the correct operator of derived class is called registry.GetFunctor(packetA.GetPacketId(), version)->operator()(packetA); @@ -561,16 +587,18 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) // Data with period and counters uint32_t period1 = 10; uint32_t dataLength1 = 8; - unsigned char data1[dataLength1]; uint32_t offset = 0; + std::unique_ptr uniqueData1 = std::make_unique(dataLength1); + unsigned char* data1 = reinterpret_cast(uniqueData1.get()); + WriteUint32(data1, offset, period1); offset += sizeOfUint32; WriteUint16(data1, offset, 4000); offset += sizeOfUint16; WriteUint16(data1, offset, 5000); - Packet packetA(packetId, dataLength1, reinterpret_cast(data1)); + Packet packetA(packetId, dataLength1, uniqueData1); PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread, sendCounterPacket); @@ -611,11 +639,12 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData) // Data with period only uint32_t period2 = 11; uint32_t dataLength2 = 4; - unsigned char data2[dataLength2]; - WriteUint32(data2, 0, period2); + std::unique_ptr uniqueData2 = std::make_unique(dataLength2); + + WriteUint32(reinterpret_cast(uniqueData2.get()), 0, period2); - Packet packetB(packetId, dataLength2, reinterpret_cast(data2)); + Packet packetB(packetId, dataLength2, uniqueData2); commandHandler(packetB); -- cgit v1.2.1