aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-02-27 16:21:41 +0000
committerJim Flynn <jim.flynn@arm.com>2020-03-18 12:59:19 +0000
commite6a2ccd09060ba93203ddc5a7f79260cedf2c147 (patch)
treec542464311f16acde42fbe01df9f4d8a78feff64
parenteba482e691bb314e1379d29f267ec3b46a082d01 (diff)
downloadarmnn-e6a2ccd09060ba93203ddc5a7f79260cedf2c147.tar.gz
IVGCVSW-4161 Provide for per model call back registration
!armnn:2810 Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: Idf56d42bd767baa5df0059a2f489f75281f8ac71
-rw-r--r--CMakeLists.txt3
-rw-r--r--src/timelineDecoder/TimelineCaptureCommandHandler.cpp17
-rw-r--r--src/timelineDecoder/TimelineCaptureCommandHandler.hpp17
-rw-r--r--src/timelineDecoder/TimelineDecoder.cpp33
-rw-r--r--src/timelineDecoder/TimelineDecoder.hpp3
-rw-r--r--src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp3
-rw-r--r--src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp6
-rw-r--r--src/timelineDecoder/tests/TimelineTests.cpp8
-rw-r--r--tests/profiling/gatordmock/GatordMockMain.cpp122
-rw-r--r--tests/profiling/gatordmock/GatordMockService.cpp25
-rw-r--r--tests/profiling/gatordmock/GatordMockService.hpp25
-rw-r--r--tests/profiling/gatordmock/tests/GatordMockTests.cpp32
12 files changed, 208 insertions, 86 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d5da0d3aad..3d0f518fe2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1002,7 +1002,7 @@ if(BUILD_GATORD_MOCK)
tests/profiling/gatordmock/StreamMetadataCommandHandler.hpp
)
- include_directories(src/profiling tests/profiling tests/profiling/gatordmock)
+ include_directories(src/profiling tests/profiling tests/profiling/gatordmock src/timelineDecoder)
add_library_ex(gatordMockService STATIC ${gatord_mock_sources})
target_include_directories(gatordMockService PRIVATE src/armnnUtils)
@@ -1012,6 +1012,7 @@ if(BUILD_GATORD_MOCK)
target_link_libraries(GatordMock
armnn
+ timelineDecoder
gatordMockService
${Boost_PROGRAM_OPTIONS_LIBRARY}
${Boost_SYSTEM_LIBRARY})
diff --git a/src/timelineDecoder/TimelineCaptureCommandHandler.cpp b/src/timelineDecoder/TimelineCaptureCommandHandler.cpp
index fb6935e247..58edd9fc43 100644
--- a/src/timelineDecoder/TimelineCaptureCommandHandler.cpp
+++ b/src/timelineDecoder/TimelineCaptureCommandHandler.cpp
@@ -6,7 +6,7 @@
#include "TimelineCaptureCommandHandler.hpp"
#include <string>
-
+#include <armnn/Logging.hpp>
namespace armnn
{
@@ -28,7 +28,15 @@ void TimelineCaptureCommandHandler::ParseData(const armnn::profiling::Packet& pa
uint32_t offset = 0;
m_PacketLength = packet.GetLength();
- if ( m_PacketLength < 8 )
+ // We are expecting TimelineDirectoryCaptureCommandHandler to set the thread id size
+ // if it not set in the constructor
+ if (m_ThreadIdSize == 0)
+ {
+ ARMNN_LOG(error) << "TimelineCaptureCommandHandler: m_ThreadIdSize has not been set";
+ return;
+ }
+
+ if (packet.GetLength() < 8)
{
return;
}
@@ -125,6 +133,11 @@ void TimelineCaptureCommandHandler::ReadEvent(const unsigned char* data, uint32_
m_TimelineDecoder.CreateEvent(event);
}
+void TimelineCaptureCommandHandler::SetThreadIdSize(uint32_t size)
+{
+ m_ThreadIdSize = size;
+}
+
void TimelineCaptureCommandHandler::operator()(const profiling::Packet& packet)
{
ParseData(packet);
diff --git a/src/timelineDecoder/TimelineCaptureCommandHandler.hpp b/src/timelineDecoder/TimelineCaptureCommandHandler.hpp
index b69e615b56..e143b5f6e5 100644
--- a/src/timelineDecoder/TimelineCaptureCommandHandler.hpp
+++ b/src/timelineDecoder/TimelineCaptureCommandHandler.hpp
@@ -5,9 +5,9 @@
#pragma once
-#include <CommandHandlerFunctor.hpp>
#include "armnn/profiling/ITimelineDecoder.hpp"
+#include <CommandHandlerFunctor.hpp>
#include <Packet.hpp>
#include <ProfilingUtils.hpp>
@@ -31,11 +31,11 @@ public:
uint32_t packetId,
uint32_t version,
ITimelineDecoder& timelineDecoder,
- uint32_t threadId_size)
- : CommandHandlerFunctor(familyId, packetId, version),
- m_TimelineDecoder(timelineDecoder),
- m_ThreadIdSize(threadId_size),
- m_PacketLength(0)
+ uint32_t threadIdSize = 0)
+ : CommandHandlerFunctor(familyId, packetId, version)
+ , m_TimelineDecoder(timelineDecoder)
+ , m_ThreadIdSize(threadIdSize)
+ , m_PacketLength(0)
{}
void operator()(const armnn::profiling::Packet& packet) override;
@@ -46,12 +46,13 @@ public:
void ReadRelationship(const unsigned char* data, uint32_t& offset);
void ReadEvent(const unsigned char* data, uint32_t& offset);
+ void SetThreadIdSize(uint32_t size);
+
private:
void ParseData(const armnn::profiling::Packet& packet);
ITimelineDecoder& m_TimelineDecoder;
-
- const uint32_t m_ThreadIdSize;
+ uint32_t m_ThreadIdSize;
unsigned int m_PacketLength;
static const ReadFunction m_ReadFunctions[];
diff --git a/src/timelineDecoder/TimelineDecoder.cpp b/src/timelineDecoder/TimelineDecoder.cpp
index 2f9ac135b4..f7f4663530 100644
--- a/src/timelineDecoder/TimelineDecoder.cpp
+++ b/src/timelineDecoder/TimelineDecoder.cpp
@@ -4,13 +4,14 @@
//
#include "TimelineDecoder.hpp"
-#include "../profiling/ProfilingUtils.hpp"
-
+#include <ProfilingUtils.hpp>
#include <iostream>
+
namespace armnn
{
namespace timelinedecoder
{
+
TimelineDecoder::TimelineStatus TimelineDecoder::CreateEntity(const Entity &entity)
{
if (m_OnNewEntityCallback == nullptr)
@@ -120,6 +121,34 @@ TimelineDecoder::TimelineStatus TimelineDecoder::SetRelationshipCallback(OnNewRe
return TimelineStatus::TimelineStatus_Success;
}
+void TimelineDecoder::SetDefaultCallbacks()
+{
+ SetEntityCallback([](Model& model, const ITimelineDecoder::Entity entity)
+ {
+ model.m_Entities.emplace_back(entity);
+ });
+
+ SetEventClassCallback([](Model& model, const ITimelineDecoder::EventClass eventClass)
+ {
+ model.m_EventClasses.emplace_back(eventClass);
+ });
+
+ SetEventCallback([](Model& model, const ITimelineDecoder::Event event)
+ {
+ model.m_Events.emplace_back(event);
+ });
+
+ SetLabelCallback([](Model& model, const ITimelineDecoder::Label label)
+ {
+ model.m_Labels.emplace_back(label);
+ });
+
+ SetRelationshipCallback([](Model& model, const ITimelineDecoder::Relationship relationship)
+ {
+ model.m_Relationships.emplace_back(relationship);
+ });
+}
+
void TimelineDecoder::print()
{
printLabels();
diff --git a/src/timelineDecoder/TimelineDecoder.hpp b/src/timelineDecoder/TimelineDecoder.hpp
index 405673164b..c6d1e4ee0a 100644
--- a/src/timelineDecoder/TimelineDecoder.hpp
+++ b/src/timelineDecoder/TimelineDecoder.hpp
@@ -39,13 +39,14 @@ public:
const Model& GetModel();
-
TimelineStatus SetEntityCallback(const OnNewEntityCallback);
TimelineStatus SetEventClassCallback(const OnNewEventClassCallback);
TimelineStatus SetEventCallback(const OnNewEventCallback);
TimelineStatus SetLabelCallback(const OnNewLabelCallback);
TimelineStatus SetRelationshipCallback(const OnNewRelationshipCallback);
+ void SetDefaultCallbacks();
+
void print();
private:
diff --git a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp
index 655e461b8c..74aefea142 100644
--- a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp
+++ b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp
@@ -4,6 +4,7 @@
//
#include "TimelineDirectoryCaptureCommandHandler.hpp"
+#include "TimelineCaptureCommandHandler.hpp"
#include <iostream>
#include <string>
@@ -41,6 +42,8 @@ void TimelineDirectoryCaptureCommandHandler::ParseData(const armnn::profiling::P
{
m_SwTraceMessages.push_back(profiling::ReadSwTraceMessage(data, offset));
}
+
+ m_TimelineCaptureCommandHandler.SetThreadIdSize(m_SwTraceHeader.m_ThreadIdBytes);
}
void TimelineDirectoryCaptureCommandHandler::Print()
diff --git a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp
index b4e0fd2d7d..a22a5d9f87 100644
--- a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp
+++ b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp
@@ -5,7 +5,8 @@
#pragma once
-#include <CommandHandlerFunctor.hpp>
+
+#include <TimelineCaptureCommandHandler.hpp>
#include <Packet.hpp>
#include <PacketBuffer.hpp>
#include <ProfilingUtils.hpp>
@@ -26,8 +27,10 @@ public:
TimelineDirectoryCaptureCommandHandler(uint32_t familyId,
uint32_t packetId,
uint32_t version,
+ TimelineCaptureCommandHandler& timelineCaptureCommandHandler,
bool quietOperation = false)
: CommandHandlerFunctor(familyId, packetId, version)
+ , m_TimelineCaptureCommandHandler(timelineCaptureCommandHandler)
, m_QuietOperation(quietOperation)
{}
@@ -40,6 +43,7 @@ private:
void ParseData(const armnn::profiling::Packet& packet);
void Print();
+ TimelineCaptureCommandHandler& m_TimelineCaptureCommandHandler;
bool m_QuietOperation;
};
diff --git a/src/timelineDecoder/tests/TimelineTests.cpp b/src/timelineDecoder/tests/TimelineTests.cpp
index 62b4330e1f..1f55515758 100644
--- a/src/timelineDecoder/tests/TimelineTests.cpp
+++ b/src/timelineDecoder/tests/TimelineTests.cpp
@@ -83,8 +83,13 @@ BOOST_AUTO_TEST_CASE(TimelineDirectoryTest)
profiling::PacketVersionResolver packetVersionResolver;
+ TimelineDecoder timelineDecoder;
+ TimelineCaptureCommandHandler timelineCaptureCommandHandler(
+ 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
+
TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
- 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), true);
+ 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
+ timelineCaptureCommandHandler, true);
sendTimelinePacket->SendTimelineMessageDirectoryPackage();
sendTimelinePacket->Commit();
@@ -151,6 +156,7 @@ BOOST_AUTO_TEST_CASE(TimelineCaptureTest)
TimelineDecoder timelineDecoder;
const TimelineDecoder::Model& model = timelineDecoder.GetModel();
+
TimelineCaptureCommandHandler timelineCaptureCommandHandler(
1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder, threadIdSize);
diff --git a/tests/profiling/gatordmock/GatordMockMain.cpp b/tests/profiling/gatordmock/GatordMockMain.cpp
index edad85cffe..e19461f6cb 100644
--- a/tests/profiling/gatordmock/GatordMockMain.cpp
+++ b/tests/profiling/gatordmock/GatordMockMain.cpp
@@ -3,60 +3,67 @@
// SPDX-License-Identifier: MIT
//
-#include "../../../src/profiling/PacketVersionResolver.hpp"
-#include "../../../src/profiling/PeriodicCounterSelectionCommandHandler.hpp"
+#include "PacketVersionResolver.hpp"
#include "CommandFileParser.hpp"
#include "CommandLineProcessor.hpp"
#include "DirectoryCaptureCommandHandler.hpp"
#include "GatordMockService.hpp"
#include "PeriodicCounterCaptureCommandHandler.hpp"
#include "PeriodicCounterSelectionResponseHandler.hpp"
+#include <TimelineDecoder.hpp>
+#include <TimelineDirectoryCaptureCommandHandler.hpp>
+#include <TimelineCaptureCommandHandler.hpp>
#include <iostream>
#include <string>
+#include <NetworkSockets.hpp>
+#include <signal.h>
-int main(int argc, char* argv[])
+using namespace armnn;
+using namespace gatordmock;
+
+// Used to capture ctrl-c so we can close any remaining sockets before exit
+static volatile bool run = true;
+void exit_capture(int signum)
{
- // Process command line arguments
- armnn::gatordmock::CommandLineProcessor cmdLine;
- if (!cmdLine.ProcessCommandLine(argc, argv))
- {
- return EXIT_FAILURE;
- }
+ IgnoreUnused(signum);
+ run = false;
+}
- armnn::profiling::PacketVersionResolver packetVersionResolver;
+bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string commandFile, bool isEchoEnabled)
+{
+ profiling::PacketVersionResolver packetVersionResolver;
// Create the Command Handler Registry
- armnn::profiling::CommandHandlerRegistry registry;
+ profiling::CommandHandlerRegistry registry;
+
+ timelinedecoder::TimelineDecoder timelineDecoder;
+ timelineDecoder.SetDefaultCallbacks();
// This functor will receive back the selection response packet.
- armnn::gatordmock::PeriodicCounterSelectionResponseHandler periodicCounterSelectionResponseHandler(
- 0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue());
+ PeriodicCounterSelectionResponseHandler periodicCounterSelectionResponseHandler(
+ 0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue());
// This functor will receive the counter data.
- armnn::gatordmock::PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler(
- 3, 0, packetVersionResolver.ResolvePacketVersion(3, 0).GetEncodedValue());
+ PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler(
+ 3, 0, packetVersionResolver.ResolvePacketVersion(3, 0).GetEncodedValue());
+
+ profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
+ 0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), false);
+
+ timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler(
+ 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
- armnn::profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
- 0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), false);
+ timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
+ 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
+ timelineCaptureCommandHandler, false);
// Register different derived functors
registry.RegisterFunctor(&periodicCounterSelectionResponseHandler);
registry.RegisterFunctor(&counterCaptureCommandHandler);
registry.RegisterFunctor(&directoryCaptureCommandHandler);
+ registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler);
+ registry.RegisterFunctor(&timelineCaptureCommandHandler);
- armnn::gatordmock::GatordMockService mockService(registry, cmdLine.IsEchoEnabled());
-
- if (!mockService.OpenListeningSocket(cmdLine.GetUdsNamespace()))
- {
- return EXIT_FAILURE;
- }
- std::cout << "Bound to UDS namespace: \\0" << cmdLine.GetUdsNamespace() << std::endl;
-
- // Wait for a single connection.
- if (-1 == mockService.BlockForOneClient())
- {
- return EXIT_FAILURE;
- }
- std::cout << "Client connection established." << std::endl;
+ GatordMockService mockService(clientConnection, registry, isEchoEnabled);
// Send receive the strweam metadata and send connection ack.
if (!mockService.WaitForStreamMetaData())
@@ -69,11 +76,60 @@ int main(int argc, char* argv[])
mockService.LaunchReceivingThread();
// Process the SET and WAIT command from the file.
- armnn::gatordmock::CommandFileParser commandLineParser;
- commandLineParser.ParseFile(cmdLine.GetCommandFile(), mockService);
+ CommandFileParser commandLineParser;
+ commandLineParser.ParseFile(commandFile, mockService);
// Once we've finished processing the file wait for the receiving thread to close.
mockService.WaitForReceivingThread();
+ if(isEchoEnabled)
+ {
+ timelineDecoder.print();
+ }
+
return EXIT_SUCCESS;
}
+
+int main(int argc, char* argv[])
+{
+ // We need to capture ctrl-c so we can close any remaining sockets before exit
+ signal(SIGINT, exit_capture);
+
+ // Process command line arguments
+ CommandLineProcessor cmdLine;
+ if (!cmdLine.ProcessCommandLine(argc, argv))
+ {
+ return EXIT_FAILURE;
+ }
+
+ std::vector<std::thread> threads;
+ std::string commandFile = cmdLine.GetCommandFile();
+
+ armnnUtils::Sockets::Initialize();
+ armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
+
+ if (!GatordMockService::OpenListeningSocket(listeningSocket, cmdLine.GetUdsNamespace(), 10))
+ {
+ return EXIT_FAILURE;
+ }
+ std::cout << "Bound to UDS namespace: \\0" << cmdLine.GetUdsNamespace() << std::endl;
+
+ // make the socket non-blocking so we can exit the loop
+ armnnUtils::Sockets::SetNonBlocking(listeningSocket);
+ while (run)
+ {
+ armnnUtils::Sockets::Socket clientConnection =
+ armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
+
+ if (clientConnection > 0)
+ {
+ threads.emplace_back(
+ std::thread(CreateMockService, clientConnection, commandFile, cmdLine.IsEchoEnabled()));
+ }
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(100u));
+ }
+
+ armnnUtils::Sockets::Close(listeningSocket);
+ std::for_each(threads.begin(), threads.end(), [](std::thread& t){t.join();});
+} \ No newline at end of file
diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp
index c5211962d3..a3f732cb55 100644
--- a/tests/profiling/gatordmock/GatordMockService.cpp
+++ b/tests/profiling/gatordmock/GatordMockService.cpp
@@ -24,11 +24,11 @@ namespace armnn
namespace gatordmock
{
-bool GatordMockService::OpenListeningSocket(std::string udsNamespace)
+bool GatordMockService::OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket,
+ const std::string udsNamespace,
+ const int numOfConnections)
{
- Sockets::Initialize();
- m_ListeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
- if (-1 == m_ListeningSocket)
+ if (-1 == listeningSocket)
{
std::cerr << ": Socket construction failed: " << strerror(errno) << std::endl;
return false;
@@ -41,13 +41,13 @@ bool GatordMockService::OpenListeningSocket(std::string udsNamespace)
udsAddress.sun_family = AF_UNIX;
// Bind the socket to the UDS namespace.
- if (-1 == bind(m_ListeningSocket, reinterpret_cast<const sockaddr*>(&udsAddress), sizeof(sockaddr_un)))
+ if (-1 == bind(listeningSocket, reinterpret_cast<const sockaddr*>(&udsAddress), sizeof(sockaddr_un)))
{
std::cerr << ": Binding on socket failed: " << strerror(errno) << std::endl;
return false;
}
- // Listen for 1 connection.
- if (-1 == listen(m_ListeningSocket, 1))
+ // Listen for 10 connections.
+ if (-1 == listen(listeningSocket, numOfConnections))
{
std::cerr << ": Listen call on socket failed: " << strerror(errno) << std::endl;
return false;
@@ -55,17 +55,6 @@ bool GatordMockService::OpenListeningSocket(std::string udsNamespace)
return true;
}
-Sockets::Socket GatordMockService::BlockForOneClient()
-{
- m_ClientConnection = Sockets::Accept(m_ListeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
- if (-1 == m_ClientConnection)
- {
- std::cerr << ": Failure when waiting for a client connection: " << strerror(errno) << std::endl;
- return -1;
- }
- return m_ClientConnection;
-}
-
bool GatordMockService::WaitForStreamMetaData()
{
if (m_EchoPackets)
diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp
index f91e902db8..c00685fff2 100644
--- a/tests/profiling/gatordmock/GatordMockService.hpp
+++ b/tests/profiling/gatordmock/GatordMockService.hpp
@@ -39,10 +39,13 @@ class GatordMockService
public:
/// @param registry reference to a command handler registry.
/// @param echoPackets if true the raw packets will be printed to stdout.
- GatordMockService(armnn::profiling::CommandHandlerRegistry& registry, bool echoPackets)
- : m_HandlerRegistry(registry)
- , m_EchoPackets(echoPackets)
- , m_CloseReceivingThread(false)
+ GatordMockService(armnnUtils::Sockets::Socket clientConnection,
+ armnn::profiling::CommandHandlerRegistry& registry,
+ bool echoPackets)
+ : m_ClientConnection(clientConnection)
+ , m_HandlerRegistry(registry)
+ , m_EchoPackets(echoPackets)
+ , m_CloseReceivingThread(false)
{
m_PacketsReceivedCount.store(0, std::memory_order_relaxed);
}
@@ -51,17 +54,14 @@ public:
{
// We have set SOCK_CLOEXEC on these sockets but we'll close them to be good citizens.
armnnUtils::Sockets::Close(m_ClientConnection);
- armnnUtils::Sockets::Close(m_ListeningSocket);
}
/// Establish the Unix domain socket and set it to listen for connections.
/// @param udsNamespace the namespace (socket address) associated with the listener.
/// @return true only if the socket has been correctly setup.
- bool OpenListeningSocket(std::string udsNamespace);
-
- /// Block waiting to accept one client to connect to the UDS.
- /// @return the file descriptor of the client connection.
- armnnUtils::Sockets::Socket BlockForOneClient();
+ static bool OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket,
+ const std::string udsNamespace,
+ const int numOfConnections = 1);
/// Once the connection is open wait to receive the stream meta data packet from the client. Reading this
/// packet differs from others as we need to determine endianness.
@@ -118,6 +118,8 @@ public:
private:
void ReceiveLoop(GatordMockService& mockService);
+ int MainLoop(armnn::profiling::CommandHandlerRegistry& registry, armnnUtils::Sockets::Socket m_ClientConnection);
+
/// Block on the client connection until a complete packet has been received. This is a placeholder function to
/// enable early testing of the tool.
/// @return true if a valid packet has been received.
@@ -145,11 +147,10 @@ private:
uint32_t m_StreamMetaDataMaxDataLen;
uint32_t m_StreamMetaDataPid;
+ armnnUtils::Sockets::Socket m_ClientConnection;
armnn::profiling::CommandHandlerRegistry& m_HandlerRegistry;
bool m_EchoPackets;
- armnnUtils::Sockets::Socket m_ListeningSocket;
- armnnUtils::Sockets::Socket m_ClientConnection;
std::thread m_ListeningThread;
std::atomic<bool> m_CloseReceivingThread;
};
diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
index 78c6f117ac..bba848588e 100644
--- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp
+++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
@@ -11,6 +11,7 @@
#include <StreamMetadataCommandHandler.hpp>
#include <TimelineDirectoryCaptureCommandHandler.hpp>
+#include <TimelineDecoder.hpp>
#include <test/ProfilingMocks.hpp>
@@ -21,7 +22,7 @@
BOOST_AUTO_TEST_SUITE(GatordMockTests)
using namespace armnn;
-using namespace std::this_thread; // sleep_for, sleep_until
+using namespace std::this_thread;
using namespace std::chrono_literals;
BOOST_AUTO_TEST_CASE(CounterCaptureHandlingTest)
@@ -118,6 +119,9 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
// Create the Command Handler Registry
profiling::CommandHandlerRegistry registry;
+ timelinedecoder::TimelineDecoder timelineDecoder;
+ timelineDecoder.SetDefaultCallbacks();
+
// Update with derived functors
gatordmock::StreamMetadataCommandHandler streamMetadataCommandHandler(
0, 0, packetVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true);
@@ -128,18 +132,29 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true);
+ timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler(
+ 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
+
timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
- 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), true);
+ 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
+ timelineCaptureCommandHandler, true);
// Register different derived functors
registry.RegisterFunctor(&streamMetadataCommandHandler);
registry.RegisterFunctor(&counterCaptureCommandHandler);
registry.RegisterFunctor(&directoryCaptureCommandHandler);
registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler);
+
// Setup the mock service to bind to the UDS.
std::string udsNamespace = "gatord_namespace";
- gatordmock::GatordMockService mockService(registry, false);
- mockService.OpenListeningSocket(udsNamespace);
+
+ armnnUtils::Sockets::Initialize();
+ armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
+
+ if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace))
+ {
+ BOOST_FAIL("Failed to open Listening Socket");
+ }
// Enable the profiling service.
armnn::IRuntime::CreationOptions::ExternalProfilingOptions options;
@@ -154,12 +169,15 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
profilingService.Update();
// Connect the profiling service to the mock Gatord.
- int clientFd = mockService.BlockForOneClient();
- if (-1 == clientFd)
+ armnnUtils::Sockets::Socket clientSocket =
+ armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
+ if (-1 == clientSocket)
{
BOOST_FAIL("Failed to connect client");
}
+ gatordmock::GatordMockService mockService(clientSocket, registry, false);
+
// Give the profiling service sending thread time start executing and send the stream metadata.
while (profilingService.GetCurrentState() != profiling::ProfilingState::WaitingForAck)
{
@@ -286,7 +304,7 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
mockService.WaitForReceivingThread();
options.m_EnableProfiling = false;
profilingService.ResetExternalProfilingOptions(options, true);
-
+ armnnUtils::Sockets::Close(listeningSocket);
// Future tests here will add counters to the ProfilingService, increment values and examine
// PeriodicCounterCapture data received. These are yet to be integrated.
}