From d0613b56cea7eba0604e0548bddffd773a4eb554 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 9 Oct 2019 16:47:04 +0100 Subject: IVGCVSW-3937 Improve the Connection Acknowledged Handler * The Connection Acknowledged Handler should report an error is it's called while in a wrong state * Stopping the threads in the ProfilingService before having to start them again * Updated the unit tests to check the changes * Removed unnecessary Packet.cpp file * Fixed memory leak Signed-off-by: Matteo Martincigh Change-Id: I8c4d33b4d97994df86fe6c9f8c659f880ec64c16 --- src/profiling/test/ProfilingTests.cpp | 239 +++++--------------------- src/profiling/test/ProfilingTests.hpp | 200 +++++++++++++++++++++ src/profiling/test/SendCounterPacketTests.cpp | 6 +- src/profiling/test/SendCounterPacketTests.hpp | 49 ++++-- 4 files changed, 284 insertions(+), 210 deletions(-) create mode 100644 src/profiling/test/ProfilingTests.hpp (limited to 'src/profiling/test') diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index de92fb9eb0..80d99dd7ab 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -3,11 +3,10 @@ // SPDX-License-Identifier: MIT // -#include "SendCounterPacketTests.hpp" +#include "ProfilingTests.hpp" #include #include -#include #include #include #include @@ -19,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -27,21 +25,16 @@ #include -#include #include #include #include -#include #include #include -#include #include #include #include -#include -#include using namespace armnn::profiling; @@ -94,59 +87,6 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons) BOOST_CHECK(vect == expectedVect); } -class TestProfilingConnectionBase :public IProfilingConnection -{ -public: - TestProfilingConnectionBase() = default; - ~TestProfilingConnectionBase() = default; - - bool IsOpen() const override { return true; } - - void Close() override {} - - bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; } - - Packet ReadPacket(uint32_t timeout) override - { - std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); - std::unique_ptr packetData; - - // Return connection acknowledged packet - return { 65536, 0, packetData }; - } -}; - -class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase -{ -public: - Packet ReadPacket(uint32_t timeout) { - if (readRequests < 3) - { - readRequests++; - throw armnn::TimeoutException("Simulate a timeout"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); - std::unique_ptr packetData; - - // Return connection acknowledged packet after three timeouts - return { 65536, 0, packetData }; - } - -private: - int readRequests = 0; -}; - -class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase -{ -public: - - Packet ReadPacket(uint32_t timeout) - { - std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); - throw armnn::Exception(": Simulate a non timeout error"); - } -}; - BOOST_AUTO_TEST_CASE(CheckCommandHandler) { PacketVersionResolver packetVersionResolver; @@ -180,7 +120,7 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler) profilingStateMachine.TransitionToState(ProfilingState::NotConnected); profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck); // commandHandler1 should give up after one timeout - CommandHandler commandHandler1(1, + CommandHandler commandHandler1(10, true, commandHandlerRegistry, packetVersionResolver); @@ -204,32 +144,24 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler) break; } - std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } commandHandler1.Stop(); BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active); - CommandHandler commandHandler2(1, + CommandHandler commandHandler2(100, false, commandHandlerRegistry, packetVersionResolver); commandHandler2.Start(testProfilingConnectionArmnnError); - for (int i = 0; i < 100; i++) - { - if (!commandHandler2.IsRunning()) - { - // commandHandler2 should stop once it encounters a non timing error - return; - } - - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - } + // commandHandler2 should not stop once it encounters a non timing error + std::this_thread::sleep_for(std::chrono::milliseconds(500)); - BOOST_ERROR("commandHandler2 has failed to stop"); + BOOST_CHECK(commandHandler2.IsRunning()); commandHandler2.Stop(); } @@ -300,33 +232,6 @@ BOOST_AUTO_TEST_CASE(CheckPacketClass) BOOST_CHECK(packetTest4.GetPacketClass() == 5); } -// Create Derived Classes -class TestFunctorA : public CommandHandlerFunctor -{ -public: - using CommandHandlerFunctor::CommandHandlerFunctor; - - int GetCount() { return m_Count; } - - void operator()(const Packet& packet) override - { - m_Count++; - } - -private: - int m_Count = 0; -}; - -class TestFunctorB : public TestFunctorA -{ - using TestFunctorA::TestFunctorA; -}; - -class TestFunctorC : public TestFunctorA -{ - using TestFunctorA::TestFunctorA; -}; - BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor) { // Hard code the version as it will be the same during a single profiling session @@ -455,6 +360,7 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver) BOOST_TEST(resolvedVersion == expectedVersion); } } + void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states) { ProfilingState newState = ProfilingState::NotConnected; @@ -664,32 +570,6 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised); } -struct LogLevelSwapper -{ -public: - LogLevelSwapper(armnn::LogSeverity severity) - { - // Set the new log level - armnn::ConfigureLogging(true, true, severity); - } - ~LogLevelSwapper() - { - // The default log level for unit tests is "Fatal" - armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); - } -}; - -struct CoutRedirect -{ -public: - CoutRedirect(std::streambuf* newStreamBuffer) - : old(std::cout.rdbuf(newStreamBuffer)) {} - ~CoutRedirect() { std::cout.rdbuf(old); } - -private: - std::streambuf* old; -}; - BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled) { // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output @@ -705,7 +585,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled) // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - CoutRedirect coutRedirect(ss.rdbuf()); + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } @@ -729,7 +609,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime) // Redirect the output to a local stream so that we can parse the warning message std::stringstream ss; - CoutRedirect coutRedirect(ss.rdbuf()); + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); profilingService.Update(); BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused")); } @@ -1949,16 +1829,18 @@ BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged) profilingState.TransitionToState(ProfilingState::WaitingForAck); BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck); // command handler received packet on ProfilingState::WaitingForAck - commandHandler(packetA); + BOOST_CHECK_NO_THROW(commandHandler(packetA)); BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active); // command handler received packet on ProfilingState::Active - commandHandler(packetA); + BOOST_CHECK_NO_THROW(commandHandler(packetA)); BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active); // command handler received different packet const uint32_t differentPacketId = 0x40000; Packet packetB(differentPacketId, dataLength1, uniqueData1); + profilingState.TransitionToState(ProfilingState::NotConnected); + profilingState.TransitionToState(ProfilingState::WaitingForAck); ConnectionAcknowledgedCommandHandler differentCommandHandler(differentPacketId, version, profilingState); BOOST_CHECK_THROW(differentCommandHandler(packetB), armnn::Exception); } @@ -2333,62 +2215,17 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1) BOOST_TEST(categoryRecordOffset == 44); } -class MockProfilingConnectionFactory : public IProfilingConnectionFactory -{ -public: - MockProfilingConnectionFactory() - : m_MockProfilingConnection(new MockProfilingConnection()) - {} - - IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override - { - return std::unique_ptr(m_MockProfilingConnection); - } - - MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; } - -private: - MockProfilingConnection* m_MockProfilingConnection; -}; - -class SwapProfilingConnectionFactoryHelper : public ProfilingService -{ -public: - SwapProfilingConnectionFactoryHelper() - : ProfilingService() - , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) - , m_BackupProfilingConnectionFactory(nullptr) - { - SwapProfilingConnectionFactory(ProfilingService::Instance(), - m_MockProfilingConnectionFactory.get(), - m_BackupProfilingConnectionFactory); - } - ~SwapProfilingConnectionFactoryHelper() - { - IProfilingConnectionFactory* temp = nullptr; - SwapProfilingConnectionFactory(ProfilingService::Instance(), - m_BackupProfilingConnectionFactory, - temp); - } - - IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); } - -private: - IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; - IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; -}; - BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) { // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning); + // Swap the profiling connection factory in the profiling service instance with our mock one SwapProfilingConnectionFactoryHelper helper; - MockProfilingConnectionFactory* mockProfilingConnectionFactory = - boost::polymorphic_downcast(helper.GetMockProfilingConnectionFactory()); - BOOST_CHECK(mockProfilingConnectionFactory); - MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); - BOOST_CHECK(mockProfilingConnection); + + // Redirect the standard output to a local stream so that we can parse the warning message + std::stringstream ss; + StreamRedirector streamRedirector(std::cout, ss.rdbuf()); // Calculate the size of a Stream Metadata packet std::string processName = GetProcessName().substr(0, 60); @@ -2408,15 +2245,15 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); profilingService.Update(); - // Redirect the output to a local stream so that we can parse the warning message - std::stringstream ss; - CoutRedirect coutRedirect(ss.rdbuf()); - // Wait for a bit to make sure that we get the packet std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + // Check that the mock profiling connection contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection->GetWrittenData(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); @@ -2433,7 +2270,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16); - // Connection Acknowledged Packet + // Create the Connection Acknowledged Packet Packet connectionAcknowledgedPacket(header); // Write the packet to the mock profiling connection @@ -2441,23 +2278,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket) // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that // the Connection Acknowledged packet gets processed by the profiling service - std::this_thread::sleep_for(std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::seconds(2)); // Check that the expected error has occurred and logged to the standard output BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist")); // The Connection Acknowledged Command Handler should not have updated the profiling state BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); } BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) { + // Swap the profiling connection factory in the profiling service instance with our mock one SwapProfilingConnectionFactoryHelper helper; - MockProfilingConnectionFactory* mockProfilingConnectionFactory = - boost::polymorphic_downcast(helper.GetMockProfilingConnectionFactory()); - BOOST_CHECK(mockProfilingConnectionFactory); - MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection(); - BOOST_CHECK(mockProfilingConnection); // Calculate the size of a Stream Metadata packet std::string processName = GetProcessName().substr(0, 60); @@ -2480,8 +2317,12 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) // Wait for a bit to make sure that we get the packet std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // Get the mock profiling connection + MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection(); + BOOST_CHECK(mockProfilingConnection); + // Check that the mock profiling connection contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection->GetWrittenData(); + const std::vector writtenData = mockProfilingConnection->GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); @@ -2498,7 +2339,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16); - // Connection Acknowledged Packet + // Create the Connection Acknowledged Packet Packet connectionAcknowledgedPacket(header); // Write the packet to the mock profiling connection @@ -2506,10 +2347,14 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket) // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that // the Connection Acknowledged packet gets processed by the profiling service - std::this_thread::sleep_for(std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::seconds(2)); // The Connection Acknowledged Command Handler should have updated the profiling state accordingly BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active); + + // Reset the profiling service to stop any running thread + options.m_EnableProfiling = false; + profilingService.ResetExternalProfilingOptions(options, true); } BOOST_AUTO_TEST_SUITE_END() diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp new file mode 100644 index 0000000000..3e6cf63efe --- /dev/null +++ b/src/profiling/test/ProfilingTests.hpp @@ -0,0 +1,200 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "SendCounterPacketTests.hpp" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace armnn +{ + +namespace profiling +{ + +struct LogLevelSwapper +{ +public: + LogLevelSwapper(armnn::LogSeverity severity) + { + // Set the new log level + armnn::ConfigureLogging(true, true, severity); + } + ~LogLevelSwapper() + { + // The default log level for unit tests is "Fatal" + armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); + } +}; + +struct CoutRedirect +{ +public: + CoutRedirect(std::streambuf* newStreamBuffer) + : m_Old(std::cout.rdbuf(newStreamBuffer)) {} + ~CoutRedirect() { std::cout.rdbuf(m_Old); } + +private: + std::streambuf* m_Old; +}; + +struct StreamRedirector +{ +public: + StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer) + : m_Stream(stream) + , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer)) + {} + ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); } + +private: + std::ostream& m_Stream; + std::streambuf* m_BackupBuffer; +}; + +class TestProfilingConnectionBase : public IProfilingConnection +{ +public: + TestProfilingConnectionBase() = default; + ~TestProfilingConnectionBase() = default; + + bool IsOpen() const override { return true; } + + void Close() override {} + + bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; } + + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + // Return connection acknowledged packet + std::unique_ptr packetData; + return Packet(65536, 0, packetData); + } +}; + +class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase +{ +public: + TestProfilingConnectionTimeoutError() + : m_ReadRequests(0) + {} + + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + if (m_ReadRequests < 3) + { + m_ReadRequests++; + throw armnn::TimeoutException("Simulate a timeout error\n"); + } + + // Return connection acknowledged packet after three timeouts + std::unique_ptr packetData; + return Packet(65536, 0, packetData); + } + +private: + int m_ReadRequests; +}; + +class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase +{ +public: + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + throw armnn::Exception("Simulate a non-timeout error"); + } +}; + +class TestFunctorA : public CommandHandlerFunctor +{ +public: + using CommandHandlerFunctor::CommandHandlerFunctor; + + int GetCount() { return m_Count; } + + void operator()(const Packet& packet) override + { + m_Count++; + } + +private: + int m_Count = 0; +}; + +class TestFunctorB : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; + +class TestFunctorC : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; + +class MockProfilingConnectionFactory : public IProfilingConnectionFactory +{ +public: + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override + { + return std::make_unique(); + } +}; + +class SwapProfilingConnectionFactoryHelper : public ProfilingService +{ +public: + using MockProfilingConnectionFactoryPtr = std::unique_ptr; + + SwapProfilingConnectionFactoryHelper() + : ProfilingService() + , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) + , m_BackupProfilingConnectionFactory(nullptr) + { + BOOST_CHECK(m_MockProfilingConnectionFactory); + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_MockProfilingConnectionFactory.get(), + m_BackupProfilingConnectionFactory); + BOOST_CHECK(m_BackupProfilingConnectionFactory); + } + ~SwapProfilingConnectionFactoryHelper() + { + BOOST_CHECK(m_BackupProfilingConnectionFactory); + IProfilingConnectionFactory* temp = nullptr; + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_BackupProfilingConnectionFactory, + temp); + } + + MockProfilingConnection* GetMockProfilingConnection() + { + IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance()); + return boost::polymorphic_downcast(profilingConnection); + } + +private: + MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; + IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; +}; + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp index 1216420383..00dad38078 100644 --- a/src/profiling/test/SendCounterPacketTests.cpp +++ b/src/profiling/test/SendCounterPacketTests.cpp @@ -2322,7 +2322,7 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1) BOOST_TEST(reservedBuffer.get()); // Check that data was actually written to the profiling connection in any order - const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + const std::vector writtenData = mockProfilingConnection.GetWrittenData(); BOOST_TEST(writtenData.size() == 3); bool foundStreamMetaDataPacket = std::find(writtenData.begin(), writtenData.end(), streamMetadataPacketsize) != writtenData.end(); @@ -2391,7 +2391,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3) BOOST_CHECK_NO_THROW(sendCounterPacket.Stop()); // Check that the buffer contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + const std::vector writtenData = mockProfilingConnection.GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); } @@ -2420,7 +2420,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4) BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck)); // Check that the buffer contains one Stream Metadata packet - const std::vector& writtenData = mockProfilingConnection.GetWrittenData(); + const std::vector writtenData = mockProfilingConnection.GetWrittenData(); BOOST_TEST(writtenData.size() == 1); BOOST_TEST(writtenData[0] == streamMetadataPacketsize); diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp index 48bab025dd..871ca74124 100644 --- a/src/profiling/test/SendCounterPacketTests.hpp +++ b/src/profiling/test/SendCounterPacketTests.hpp @@ -12,6 +12,7 @@ #include #include +#include #include namespace armnn @@ -19,6 +20,7 @@ namespace armnn namespace profiling { + class MockProfilingConnection : public IProfilingConnection { public: @@ -28,9 +30,19 @@ public: , m_Packet() {} - bool IsOpen() const override { return m_IsOpen; } + bool IsOpen() const override + { + std::lock_guard lock(m_Mutex); + + return m_IsOpen; + } + + void Close() override + { + std::lock_guard lock(m_Mutex); - void Close() override { m_IsOpen = false; } + m_IsOpen = false; + } bool WritePacket(const unsigned char* buffer, uint32_t length) override { @@ -39,11 +51,15 @@ public: return false; } + std::lock_guard lock(m_Mutex); + m_WrittenData.push_back(length); return true; } bool WritePacket(Packet&& packet) { + std::lock_guard lock(m_Mutex); + m_Packet = std::move(packet); return true; } @@ -51,19 +67,32 @@ public: Packet ReadPacket(uint32_t timeout) override { // Simulate a delay in the reading process - std::this_thread::sleep_for(std::chrono::milliseconds(500)); + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + std::lock_guard lock(m_Mutex); return std::move(m_Packet); } - const std::vector& GetWrittenData() const { return m_WrittenData; } + const std::vector GetWrittenData() const + { + std::lock_guard lock(m_Mutex); + + return m_WrittenData; + } + + void Clear() + { + std::lock_guard lock(m_Mutex); - void Clear() { m_WrittenData.clear(); } + m_WrittenData.clear(); + } private: bool m_IsOpen; std::vector m_WrittenData; Packet m_Packet; + mutable std::mutex m_Mutex; }; class MockPacketBuffer : public IPacketBuffer @@ -162,7 +191,7 @@ public: IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); reservedSize = 0; if (requestedSize > m_MaxBufferSize) @@ -176,7 +205,7 @@ public: void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); packetBuffer->Commit(size); m_BufferList.push_back(std::move(packetBuffer)); @@ -185,14 +214,14 @@ public: void Release(IPacketBufferPtr& packetBuffer) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); packetBuffer->Release(); } IPacketBufferPtr GetReadableBuffer() override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); if (m_BufferList.empty()) { @@ -206,7 +235,7 @@ public: void MarkRead(IPacketBufferPtr& packetBuffer) override { - std::unique_lock lock(m_Mutex); + std::lock_guard lock(m_Mutex); m_ReadSize += packetBuffer->GetSize(); packetBuffer->MarkRead(); -- cgit v1.2.1