diff options
Diffstat (limited to 'tests/profiling/gatordmock/tests/GatordMockTests.cpp')
-rw-r--r-- | tests/profiling/gatordmock/tests/GatordMockTests.cpp | 228 |
1 files changed, 136 insertions, 92 deletions
diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp index 7d938bd404..7417946844 100644 --- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp +++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp @@ -9,12 +9,14 @@ #include <LabelsAndEventClasses.hpp> #include <PeriodicCounterCaptureCommandHandler.hpp> #include <ProfilingService.hpp> -#include <StreamMetadataCommandHandler.hpp> #include <TimelinePacketWriterFactory.hpp> #include <TimelineDirectoryCaptureCommandHandler.hpp> #include <TimelineDecoder.hpp> +#include <Runtime.hpp> +#include "../../src/backends/backendsCommon/test/MockBackend.hpp" + #include <boost/cast.hpp> #include <boost/test/test_tools.hpp> #include <boost/test/unit_test_suite.hpp> @@ -104,6 +106,19 @@ BOOST_AUTO_TEST_CASE(CounterCaptureHandlingTest) } } +void WaitFor(std::function<bool()> predicate, std::string errorMsg, uint32_t timeout = 2000, uint32_t sleepTime = 50) +{ + uint32_t timeSlept = 0; + while (!predicate()) + { + if (timeSlept >= timeout) + { + BOOST_FAIL("Timeout: " + errorMsg); + } + std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); + timeSlept += sleepTime; + } +} void CheckTimelineDirectory(timelinedecoder::TimelineDirectoryCaptureCommandHandler& commandHandler) { @@ -211,43 +226,6 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) // The purpose of this test is to setup both sides of the profiling service and get to the point of receiving // performance data. - //These variables are used to wait for the profiling service - uint32_t timeout = 2000; - uint32_t sleepTime = 50; - uint32_t timeSlept = 0; - - profiling::PacketVersionResolver packetVersionResolver; - - // Create the Command Handler Registry - profiling::CommandHandlerRegistry registry; - - timelinedecoder::TimelineDecoder timelineDecoder; - timelineDecoder.SetDefaultCallbacks(); - - // Update with derived functors - gatordmock::StreamMetadataCommandHandler streamMetadataCommandHandler( - 0, 0, packetVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true); - - gatordmock::PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler( - 0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true); - - profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler( - 0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true); - - timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler( - 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder); - - timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler( - 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), - timelineCaptureCommandHandler, true); - - // Register different derived functors - registry.RegisterFunctor(&streamMetadataCommandHandler); - registry.RegisterFunctor(&counterCaptureCommandHandler); - registry.RegisterFunctor(&directoryCaptureCommandHandler); - registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler); - registry.RegisterFunctor(&timelineCaptureCommandHandler); - // Setup the mock service to bind to the UDS. std::string udsNamespace = "gatord_namespace"; @@ -279,18 +257,15 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) BOOST_FAIL("Failed to connect client"); } - gatordmock::GatordMockService mockService(clientSocket, registry, false); + gatordmock::GatordMockService mockService(clientSocket, false); + + timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder(); + profiling::DirectoryCaptureCommandHandler& directoryCaptureCommandHandler = + mockService.GetDirectoryCaptureCommandHandler(); // Give the profiling service sending thread time start executing and send the stream metadata. - while (profilingService.GetCurrentState() != profiling::ProfilingState::WaitingForAck) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: Profiling service did not switch to WaitingForAck state"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + WaitFor([&](){return profilingService.GetCurrentState() == profiling::ProfilingState::WaitingForAck;}, + "Profiling service did not switch to WaitingForAck state"); profilingService.Update(); // Read the stream metadata on the mock side. @@ -300,55 +275,21 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) } // Send Ack from GatorD mockService.SendConnectionAck(); + // And start to listen for packets + mockService.LaunchReceivingThread(); - timeSlept = 0; - while (profilingService.GetCurrentState() != profiling::ProfilingState::Active) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: Profiling service did not switch to Active state"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + WaitFor([&](){return profilingService.GetCurrentState() == profiling::ProfilingState::Active;}, + "Profiling service did not switch to Active state"); - mockService.LaunchReceivingThread(); // As part of the default startup of the profiling service a counter directory packet will be sent. - timeSlept = 0; - while (!directoryCaptureCommandHandler.ParsedCounterDirectory()) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: MockGatord did not receive counter directory packet"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + WaitFor([&](){return directoryCaptureCommandHandler.ParsedCounterDirectory();}, + "MockGatord did not receive counter directory packet"); - // As part of the default startup of the profiling service a counter directory packet will be sent. - timeSlept = 0; - while (!directoryCaptureCommandHandler.ParsedCounterDirectory()) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: MockGatord did not receive counter directory packet"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } + // Following that we will receive a collection of well known timeline labels and event classes + WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;}, + "MockGatord did not receive well known timeline labels and event classes"); - timeSlept = 0; - while (timelineDecoder.GetModel().m_EventClasses.size() < 2) - { - if (timeSlept >= timeout) - { - BOOST_FAIL("Timeout: MockGatord did not receive well known timeline labels"); - } - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); - timeSlept += sleepTime; - } - - CheckTimelineDirectory(timelineDirectoryCaptureCommandHandler); + CheckTimelineDirectory(mockService.GetTimelineDirectoryCaptureCommandHandler()); // Verify the commonly used timeline packets sent when the profiling service enters the active state CheckTimelinePackets(timelineDecoder); @@ -439,4 +380,107 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd) // PeriodicCounterCapture data received. These are yet to be integrated. } +BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation) +{ + armnn::MockBackendInitialiser initialiser; + // Setup the mock service to bind to the UDS. + std::string udsNamespace = "gatord_namespace"; + + armnnUtils::Sockets::Initialize(); + armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); + + if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace)) + { + BOOST_FAIL("Failed to open Listening Socket"); + } + + armnn::IRuntime::CreationOptions options; + options.m_ProfilingOptions.m_EnableProfiling = true; + armnn::Runtime runtime(options); + + armnnUtils::Sockets::Socket clientConnection; + clientConnection = armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC); + gatordmock::GatordMockService mockService(clientConnection, false); + + // Read the stream metadata on the mock side. + if (!mockService.WaitForStreamMetaData()) + { + BOOST_FAIL("Failed to receive StreamMetaData"); + } + + armnn::MockBackendProfilingService mockProfilingService = armnn::MockBackendProfilingService::Instance(); + armnn::MockBackendProfilingContext *mockBackEndProfilingContext = mockProfilingService.GetContext(); + + // Send Ack from GatorD + mockService.SendConnectionAck(); + // And start to listen for packets + mockService.LaunchReceivingThread(); + + // Build and optimize a simple network while we wait + INetworkPtr net(INetwork::Create()); + + IConnectableLayer* input = net->AddInputLayer(0, "input"); + + NormalizationDescriptor descriptor; + IConnectableLayer* normalize = net->AddNormalizationLayer(descriptor, "normalization"); + + IConnectableLayer* output = net->AddOutputLayer(0, "output"); + + input->GetOutputSlot(0).Connect(normalize->GetInputSlot(0)); + normalize->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 4 }, DataType::Float32)); + normalize->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 4 }, DataType::Float32)); + + std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef }; + IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime.GetDeviceSpec()); + + WaitFor([&](){return mockService.GetDirectoryCaptureCommandHandler().ParsedCounterDirectory();}, + "MockGatord did not receive counter directory packet"); + + timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder(); + + WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;}, + "MockGatord did not receive well known timeline labels"); + + // Packets we expect from SendWellKnownLabelsAndEventClassesTest + BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 0); + BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 2); + BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 10); + BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 0); + BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0); + + mockService.SendDeactivateTimelinePacket(); + + WaitFor([&](){return !mockBackEndProfilingContext->TimelineReportingEnabled();}, + "Timeline packets were not deactivated"); + + // Load the network into runtime now that timeline reporting is disabled + armnn::NetworkId netId; + runtime.LoadNetwork(netId, std::move(optNet)); + + // Now activate timeline packets + mockService.SendActivateTimelinePacket(); + + WaitFor([&](){return mockBackEndProfilingContext->TimelineReportingEnabled();}, + "Timeline packets were not activated"); + + // Once timeline packets have been reactivated the ActivateTimelineReportingCommandHandler will resend the + // SendWellKnownLabelsAndEventClasses and then send the structure of any loaded networks + WaitFor([&](){return timelineDecoder.GetModel().m_Labels.size() >= 24;}, + "MockGatord did not receive well known timeline labels"); + + // Packets we expect from SendWellKnownLabelsAndEventClassesTest * 2 and the loaded model + BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 5); + BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 4); + BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 24); + BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 28); + BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0); + + mockService.WaitForReceivingThread(); + armnnUtils::Sockets::Close(listeningSocket); + + GetProfilingService(&runtime).Disconnect(); +} + BOOST_AUTO_TEST_SUITE_END() |