aboutsummaryrefslogtreecommitdiff
path: root/src/profiling/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/profiling/test')
-rw-r--r--src/profiling/test/ProfilingTests.cpp239
-rw-r--r--src/profiling/test/ProfilingTests.hpp200
-rw-r--r--src/profiling/test/SendCounterPacketTests.cpp6
-rw-r--r--src/profiling/test/SendCounterPacketTests.hpp49
4 files changed, 284 insertions, 210 deletions
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index de92fb9eb0..80d99dd7ab 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -3,11 +3,10 @@
// SPDX-License-Identifier: MIT
//
-#include "SendCounterPacketTests.hpp"
+#include "ProfilingTests.hpp"
#include <CommandHandler.hpp>
#include <CommandHandlerKey.hpp>
-#include <CommandHandlerFunctor.hpp>
#include <CommandHandlerRegistry.hpp>
#include <ConnectionAcknowledgedCommandHandler.hpp>
#include <CounterDirectory.hpp>
@@ -19,7 +18,6 @@
#include <PeriodicCounterCapture.hpp>
#include <PeriodicCounterSelectionCommandHandler.hpp>
#include <ProfilingStateMachine.hpp>
-#include <ProfilingService.hpp>
#include <ProfilingUtils.hpp>
#include <RequestCounterDirectoryCommandHandler.hpp>
#include <Runtime.hpp>
@@ -27,21 +25,16 @@
#include <armnn/Conversion.hpp>
-#include <Logging.hpp>
#include <armnn/Utils.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/numeric/conversion/cast.hpp>
-#include <boost/test/unit_test.hpp>
#include <cstdint>
#include <cstring>
-#include <iostream>
#include <limits>
#include <map>
#include <random>
-#include <thread>
-#include <chrono>
using namespace armnn::profiling;
@@ -94,59 +87,6 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
BOOST_CHECK(vect == expectedVect);
}
-class TestProfilingConnectionBase :public IProfilingConnection
-{
-public:
- TestProfilingConnectionBase() = default;
- ~TestProfilingConnectionBase() = default;
-
- bool IsOpen() const override { return true; }
-
- void Close() override {}
-
- bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
-
- Packet ReadPacket(uint32_t timeout) override
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
- std::unique_ptr<char[]> packetData;
-
- // Return connection acknowledged packet
- return { 65536, 0, packetData };
- }
-};
-
-class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
-{
-public:
- Packet ReadPacket(uint32_t timeout) {
- if (readRequests < 3)
- {
- readRequests++;
- throw armnn::TimeoutException("Simulate a timeout");
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
- std::unique_ptr<char[]> packetData;
-
- // Return connection acknowledged packet after three timeouts
- return { 65536, 0, packetData };
- }
-
-private:
- int readRequests = 0;
-};
-
-class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase
-{
-public:
-
- Packet ReadPacket(uint32_t timeout)
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
- throw armnn::Exception(": Simulate a non timeout error");
- }
-};
-
BOOST_AUTO_TEST_CASE(CheckCommandHandler)
{
PacketVersionResolver packetVersionResolver;
@@ -180,7 +120,7 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler)
profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
// commandHandler1 should give up after one timeout
- CommandHandler commandHandler1(1,
+ CommandHandler commandHandler1(10,
true,
commandHandlerRegistry,
packetVersionResolver);
@@ -204,32 +144,24 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler)
break;
}
- std::this_thread::sleep_for(std::chrono::milliseconds(5));
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
commandHandler1.Stop();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
- CommandHandler commandHandler2(1,
+ CommandHandler commandHandler2(100,
false,
commandHandlerRegistry,
packetVersionResolver);
commandHandler2.Start(testProfilingConnectionArmnnError);
- for (int i = 0; i < 100; i++)
- {
- if (!commandHandler2.IsRunning())
- {
- // commandHandler2 should stop once it encounters a non timing error
- return;
- }
-
- std::this_thread::sleep_for(std::chrono::milliseconds(5));
- }
+ // commandHandler2 should not stop once it encounters a non timing error
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
- BOOST_ERROR("commandHandler2 has failed to stop");
+ BOOST_CHECK(commandHandler2.IsRunning());
commandHandler2.Stop();
}
@@ -300,33 +232,6 @@ BOOST_AUTO_TEST_CASE(CheckPacketClass)
BOOST_CHECK(packetTest4.GetPacketClass() == 5);
}
-// Create Derived Classes
-class TestFunctorA : public CommandHandlerFunctor
-{
-public:
- using CommandHandlerFunctor::CommandHandlerFunctor;
-
- int GetCount() { return m_Count; }
-
- void operator()(const Packet& packet) override
- {
- m_Count++;
- }
-
-private:
- int m_Count = 0;
-};
-
-class TestFunctorB : public TestFunctorA
-{
- using TestFunctorA::TestFunctorA;
-};
-
-class TestFunctorC : public TestFunctorA
-{
- using TestFunctorA::TestFunctorA;
-};
-
BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
{
// Hard code the version as it will be the same during a single profiling session
@@ -455,6 +360,7 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
BOOST_TEST(resolvedVersion == expectedVersion);
}
}
+
void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
{
ProfilingState newState = ProfilingState::NotConnected;
@@ -664,32 +570,6 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled)
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
}
-struct LogLevelSwapper
-{
-public:
- LogLevelSwapper(armnn::LogSeverity severity)
- {
- // Set the new log level
- armnn::ConfigureLogging(true, true, severity);
- }
- ~LogLevelSwapper()
- {
- // The default log level for unit tests is "Fatal"
- armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
- }
-};
-
-struct CoutRedirect
-{
-public:
- CoutRedirect(std::streambuf* newStreamBuffer)
- : old(std::cout.rdbuf(newStreamBuffer)) {}
- ~CoutRedirect() { std::cout.rdbuf(old); }
-
-private:
- std::streambuf* old;
-};
-
BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
{
// Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
@@ -705,7 +585,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
// Redirect the output to a local stream so that we can parse the warning message
std::stringstream ss;
- CoutRedirect coutRedirect(ss.rdbuf());
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
profilingService.Update();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
@@ -729,7 +609,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime)
// Redirect the output to a local stream so that we can parse the warning message
std::stringstream ss;
- CoutRedirect coutRedirect(ss.rdbuf());
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
profilingService.Update();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
@@ -1949,16 +1829,18 @@ BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged)
profilingState.TransitionToState(ProfilingState::WaitingForAck);
BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck);
// command handler received packet on ProfilingState::WaitingForAck
- commandHandler(packetA);
+ BOOST_CHECK_NO_THROW(commandHandler(packetA));
BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
// command handler received packet on ProfilingState::Active
- commandHandler(packetA);
+ BOOST_CHECK_NO_THROW(commandHandler(packetA));
BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
// command handler received different packet
const uint32_t differentPacketId = 0x40000;
Packet packetB(differentPacketId, dataLength1, uniqueData1);
+ profilingState.TransitionToState(ProfilingState::NotConnected);
+ profilingState.TransitionToState(ProfilingState::WaitingForAck);
ConnectionAcknowledgedCommandHandler differentCommandHandler(differentPacketId, version, profilingState);
BOOST_CHECK_THROW(differentCommandHandler(packetB), armnn::Exception);
}
@@ -2333,62 +2215,17 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
BOOST_TEST(categoryRecordOffset == 44);
}
-class MockProfilingConnectionFactory : public IProfilingConnectionFactory
-{
-public:
- MockProfilingConnectionFactory()
- : m_MockProfilingConnection(new MockProfilingConnection())
- {}
-
- IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
- {
- return std::unique_ptr<MockProfilingConnection>(m_MockProfilingConnection);
- }
-
- MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; }
-
-private:
- MockProfilingConnection* m_MockProfilingConnection;
-};
-
-class SwapProfilingConnectionFactoryHelper : public ProfilingService
-{
-public:
- SwapProfilingConnectionFactoryHelper()
- : ProfilingService()
- , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
- , m_BackupProfilingConnectionFactory(nullptr)
- {
- SwapProfilingConnectionFactory(ProfilingService::Instance(),
- m_MockProfilingConnectionFactory.get(),
- m_BackupProfilingConnectionFactory);
- }
- ~SwapProfilingConnectionFactoryHelper()
- {
- IProfilingConnectionFactory* temp = nullptr;
- SwapProfilingConnectionFactory(ProfilingService::Instance(),
- m_BackupProfilingConnectionFactory,
- temp);
- }
-
- IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); }
-
-private:
- IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
- IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
-};
-
BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
{
// Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+ // Swap the profiling connection factory in the profiling service instance with our mock one
SwapProfilingConnectionFactoryHelper helper;
- MockProfilingConnectionFactory* mockProfilingConnectionFactory =
- boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
- BOOST_CHECK(mockProfilingConnectionFactory);
- MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
- BOOST_CHECK(mockProfilingConnection);
+
+ // Redirect the standard output to a local stream so that we can parse the warning message
+ std::stringstream ss;
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
// Calculate the size of a Stream Metadata packet
std::string processName = GetProcessName().substr(0, 60);
@@ -2408,15 +2245,15 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
profilingService.Update();
- // Redirect the output to a local stream so that we can parse the warning message
- std::stringstream ss;
- CoutRedirect coutRedirect(ss.rdbuf());
-
// Wait for a bit to make sure that we get the packet
std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
// Check that the mock profiling connection contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
@@ -2433,7 +2270,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
uint32_t header = ((packetFamily & 0x0000003F) << 26) |
((packetId & 0x000003FF) << 16);
- // Connection Acknowledged Packet
+ // Create the Connection Acknowledged Packet
Packet connectionAcknowledgedPacket(header);
// Write the packet to the mock profiling connection
@@ -2441,23 +2278,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
// Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
// the Connection Acknowledged packet gets processed by the profiling service
- std::this_thread::sleep_for(std::chrono::seconds(1));
+ std::this_thread::sleep_for(std::chrono::seconds(2));
// Check that the expected error has occurred and logged to the standard output
BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist"));
// The Connection Acknowledged Command Handler should not have updated the profiling state
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
}
BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
{
+ // Swap the profiling connection factory in the profiling service instance with our mock one
SwapProfilingConnectionFactoryHelper helper;
- MockProfilingConnectionFactory* mockProfilingConnectionFactory =
- boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
- BOOST_CHECK(mockProfilingConnectionFactory);
- MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
- BOOST_CHECK(mockProfilingConnection);
// Calculate the size of a Stream Metadata packet
std::string processName = GetProcessName().substr(0, 60);
@@ -2480,8 +2317,12 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
// Wait for a bit to make sure that we get the packet
std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
// Check that the mock profiling connection contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
@@ -2498,7 +2339,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
uint32_t header = ((packetFamily & 0x0000003F) << 26) |
((packetId & 0x000003FF) << 16);
- // Connection Acknowledged Packet
+ // Create the Connection Acknowledged Packet
Packet connectionAcknowledgedPacket(header);
// Write the packet to the mock profiling connection
@@ -2506,10 +2347,14 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
// Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
// the Connection Acknowledged packet gets processed by the profiling service
- std::this_thread::sleep_for(std::chrono::seconds(1));
+ std::this_thread::sleep_for(std::chrono::seconds(2));
// The Connection Acknowledged Command Handler should have updated the profiling state accordingly
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
}
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
new file mode 100644
index 0000000000..3e6cf63efe
--- /dev/null
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -0,0 +1,200 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "SendCounterPacketTests.hpp"
+
+#include <CommandHandlerFunctor.hpp>
+#include <IProfilingConnection.hpp>
+#include <IProfilingConnectionFactory.hpp>
+#include <Logging.hpp>
+#include <ProfilingService.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+#include <chrono>
+#include <iostream>
+#include <thread>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+struct LogLevelSwapper
+{
+public:
+ LogLevelSwapper(armnn::LogSeverity severity)
+ {
+ // Set the new log level
+ armnn::ConfigureLogging(true, true, severity);
+ }
+ ~LogLevelSwapper()
+ {
+ // The default log level for unit tests is "Fatal"
+ armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
+ }
+};
+
+struct CoutRedirect
+{
+public:
+ CoutRedirect(std::streambuf* newStreamBuffer)
+ : m_Old(std::cout.rdbuf(newStreamBuffer)) {}
+ ~CoutRedirect() { std::cout.rdbuf(m_Old); }
+
+private:
+ std::streambuf* m_Old;
+};
+
+struct StreamRedirector
+{
+public:
+ StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
+ : m_Stream(stream)
+ , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
+ {}
+ ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); }
+
+private:
+ std::ostream& m_Stream;
+ std::streambuf* m_BackupBuffer;
+};
+
+class TestProfilingConnectionBase : public IProfilingConnection
+{
+public:
+ TestProfilingConnectionBase() = default;
+ ~TestProfilingConnectionBase() = default;
+
+ bool IsOpen() const override { return true; }
+
+ void Close() override {}
+
+ bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
+
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ // Return connection acknowledged packet
+ std::unique_ptr<char[]> packetData;
+ return Packet(65536, 0, packetData);
+ }
+};
+
+class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
+{
+public:
+ TestProfilingConnectionTimeoutError()
+ : m_ReadRequests(0)
+ {}
+
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ if (m_ReadRequests < 3)
+ {
+ m_ReadRequests++;
+ throw armnn::TimeoutException("Simulate a timeout error\n");
+ }
+
+ // Return connection acknowledged packet after three timeouts
+ std::unique_ptr<char[]> packetData;
+ return Packet(65536, 0, packetData);
+ }
+
+private:
+ int m_ReadRequests;
+};
+
+class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
+{
+public:
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ throw armnn::Exception("Simulate a non-timeout error");
+ }
+};
+
+class TestFunctorA : public CommandHandlerFunctor
+{
+public:
+ using CommandHandlerFunctor::CommandHandlerFunctor;
+
+ int GetCount() { return m_Count; }
+
+ void operator()(const Packet& packet) override
+ {
+ m_Count++;
+ }
+
+private:
+ int m_Count = 0;
+};
+
+class TestFunctorB : public TestFunctorA
+{
+ using TestFunctorA::TestFunctorA;
+};
+
+class TestFunctorC : public TestFunctorA
+{
+ using TestFunctorA::TestFunctorA;
+};
+
+class MockProfilingConnectionFactory : public IProfilingConnectionFactory
+{
+public:
+ IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
+ {
+ return std::make_unique<MockProfilingConnection>();
+ }
+};
+
+class SwapProfilingConnectionFactoryHelper : public ProfilingService
+{
+public:
+ using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
+
+ SwapProfilingConnectionFactoryHelper()
+ : ProfilingService()
+ , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
+ , m_BackupProfilingConnectionFactory(nullptr)
+ {
+ BOOST_CHECK(m_MockProfilingConnectionFactory);
+ SwapProfilingConnectionFactory(ProfilingService::Instance(),
+ m_MockProfilingConnectionFactory.get(),
+ m_BackupProfilingConnectionFactory);
+ BOOST_CHECK(m_BackupProfilingConnectionFactory);
+ }
+ ~SwapProfilingConnectionFactoryHelper()
+ {
+ BOOST_CHECK(m_BackupProfilingConnectionFactory);
+ IProfilingConnectionFactory* temp = nullptr;
+ SwapProfilingConnectionFactory(ProfilingService::Instance(),
+ m_BackupProfilingConnectionFactory,
+ temp);
+ }
+
+ MockProfilingConnection* GetMockProfilingConnection()
+ {
+ IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
+ return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
+ }
+
+private:
+ MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
+ IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
+};
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp
index 1216420383..00dad38078 100644
--- a/src/profiling/test/SendCounterPacketTests.cpp
+++ b/src/profiling/test/SendCounterPacketTests.cpp
@@ -2322,7 +2322,7 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1)
BOOST_TEST(reservedBuffer.get());
// Check that data was actually written to the profiling connection in any order
- const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
BOOST_TEST(writtenData.size() == 3);
bool foundStreamMetaDataPacket =
std::find(writtenData.begin(), writtenData.end(), streamMetadataPacketsize) != writtenData.end();
@@ -2391,7 +2391,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3)
BOOST_CHECK_NO_THROW(sendCounterPacket.Stop());
// Check that the buffer contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
}
@@ -2420,7 +2420,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4)
BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck));
// Check that the buffer contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp
index 48bab025dd..871ca74124 100644
--- a/src/profiling/test/SendCounterPacketTests.hpp
+++ b/src/profiling/test/SendCounterPacketTests.hpp
@@ -12,6 +12,7 @@
#include <armnn/Optional.hpp>
#include <armnn/Conversion.hpp>
+#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
namespace armnn
@@ -19,6 +20,7 @@ namespace armnn
namespace profiling
{
+
class MockProfilingConnection : public IProfilingConnection
{
public:
@@ -28,9 +30,19 @@ public:
, m_Packet()
{}
- bool IsOpen() const override { return m_IsOpen; }
+ bool IsOpen() const override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ return m_IsOpen;
+ }
+
+ void Close() override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
- void Close() override { m_IsOpen = false; }
+ m_IsOpen = false;
+ }
bool WritePacket(const unsigned char* buffer, uint32_t length) override
{
@@ -39,11 +51,15 @@ public:
return false;
}
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
m_WrittenData.push_back(length);
return true;
}
bool WritePacket(Packet&& packet)
{
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
m_Packet = std::move(packet);
return true;
}
@@ -51,19 +67,32 @@ public:
Packet ReadPacket(uint32_t timeout) override
{
// Simulate a delay in the reading process
- std::this_thread::sleep_for(std::chrono::milliseconds(500));
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ std::lock_guard<std::mutex> lock(m_Mutex);
return std::move(m_Packet);
}
- const std::vector<uint32_t>& GetWrittenData() const { return m_WrittenData; }
+ const std::vector<uint32_t> GetWrittenData() const
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ return m_WrittenData;
+ }
+
+ void Clear()
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
- void Clear() { m_WrittenData.clear(); }
+ m_WrittenData.clear();
+ }
private:
bool m_IsOpen;
std::vector<uint32_t> m_WrittenData;
Packet m_Packet;
+ mutable std::mutex m_Mutex;
};
class MockPacketBuffer : public IPacketBuffer
@@ -162,7 +191,7 @@ public:
IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
reservedSize = 0;
if (requestedSize > m_MaxBufferSize)
@@ -176,7 +205,7 @@ public:
void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
packetBuffer->Commit(size);
m_BufferList.push_back(std::move(packetBuffer));
@@ -185,14 +214,14 @@ public:
void Release(IPacketBufferPtr& packetBuffer) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
packetBuffer->Release();
}
IPacketBufferPtr GetReadableBuffer() override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
if (m_BufferList.empty())
{
@@ -206,7 +235,7 @@ public:
void MarkRead(IPacketBufferPtr& packetBuffer) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
m_ReadSize += packetBuffer->GetSize();
packetBuffer->MarkRead();