aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKeith Davis <keith.davis@arm.com>2020-03-30 10:43:41 +0100
committerFinn Williams <Finn.Williams@arm.com>2020-04-02 16:56:24 +0100
commit33ed221e2e8e3a77b5f543061e0cce07b259fc64 (patch)
tree8467f2e4ce019bfa2837ae1030c321509414780c
parent0fe279bbf22fd2116b283b9df61076d526f115e4 (diff)
downloadarmnn-33ed221e2e8e3a77b5f543061e0cce07b259fc64.tar.gz
IVGCVSW-4455 Add an Activate and Deactivate control packet to the protocol
* Add Activate/Deactivate command handlers * Add IReportStructure, INotifyBackends single function interfaces * Add overrided mechanism to report structure in Runtime.cpp * Add overrided mechanism to notify backends in ProfilingService.cpp * Add optional IReportStructure argument to ProfilingService constructor for use in ActivateTimelineReportingCommandHandler * Refactoring and tidying up indentation * Removal of unused code in ProfilingUtils.cpp and ProfilingService.cpp * Added GatordMock end to end test * Fixed an issue with SendCounterPacket sending duplicate packets * Fixed an issue with DirectoryCaptureCommandHandler handling of Optional Signed-off-by: Keith Davis <keith.davis@arm.com> Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I5ef1b74171459bfc649861dedf99921d22c9e63f
-rw-r--r--Android.mk2
-rw-r--r--CMakeLists.txt6
-rw-r--r--include/armnn/IRuntime.hpp2
-rw-r--r--include/armnn/backends/profiling/IBackendProfilingContext.hpp1
-rw-r--r--src/armnn/LoadedNetwork.cpp37
-rw-r--r--src/armnn/LoadedNetwork.hpp2
-rw-r--r--src/armnn/Runtime.cpp23
-rw-r--r--src/armnn/Runtime.hpp16
-rw-r--r--src/armnn/test/RuntimeTests.cpp9
-rw-r--r--src/backends/backendsCommon/test/MockBackend.cpp1
-rw-r--r--src/backends/backendsCommon/test/MockBackend.hpp13
-rw-r--r--src/profiling/ActivateTimelineReportingCommandHandler.cpp63
-rw-r--r--src/profiling/ActivateTimelineReportingCommandHandler.hpp54
-rw-r--r--src/profiling/CommandHandler.hpp30
-rw-r--r--src/profiling/DeactivateTimelineReportingCommandHandler.cpp53
-rw-r--r--src/profiling/DeactivateTimelineReportingCommandHandler.hpp45
-rw-r--r--src/profiling/DirectoryCaptureCommandHandler.cpp3
-rw-r--r--src/profiling/DirectoryCaptureCommandHandler.hpp2
-rw-r--r--src/profiling/INotifyBackends.hpp24
-rw-r--r--src/profiling/IReportStructure.hpp24
-rw-r--r--src/profiling/PacketVersionResolver.cpp13
-rw-r--r--src/profiling/PacketVersionResolver.hpp3
-rw-r--r--src/profiling/ProfilingService.cpp16
-rw-r--r--src/profiling/ProfilingService.hpp104
-rw-r--r--src/profiling/ProfilingUtils.cpp13
-rw-r--r--src/profiling/ProfilingUtils.hpp2
-rw-r--r--src/profiling/SendCounterPacket.cpp41
-rw-r--r--src/profiling/TimelineUtilityMethods.cpp2
-rw-r--r--src/profiling/test/ProfilingMocks.hpp15
-rw-r--r--src/profiling/test/ProfilingTestUtils.cpp8
-rw-r--r--src/profiling/test/ProfilingTestUtils.hpp6
-rw-r--r--src/profiling/test/ProfilingTests.cpp130
-rw-r--r--src/profiling/test/ProfilingTests.hpp2
-rw-r--r--src/profiling/test/SendCounterPacketTests.cpp15
-rw-r--r--src/profiling/test/SendTimelinePacketTests.cpp11
-rw-r--r--tests/profiling/gatordmock/GatordMockMain.cpp44
-rw-r--r--tests/profiling/gatordmock/GatordMockService.cpp27
-rw-r--r--tests/profiling/gatordmock/GatordMockService.hpp75
-rw-r--r--tests/profiling/gatordmock/tests/GatordMockTests.cpp228
39 files changed, 918 insertions, 247 deletions
diff --git a/Android.mk b/Android.mk
index 87b1f9ac1a..6723debbe3 100644
--- a/Android.mk
+++ b/Android.mk
@@ -184,6 +184,7 @@ LOCAL_SRC_FILES := \
src/armnn/layers/SwitchLayer.cpp \
src/armnn/layers/TransposeConvolution2dLayer.cpp \
src/armnn/layers/TransposeLayer.cpp \
+ src/profiling/ActivateTimelineReportingCommandHandler.cpp \
src/profiling/BufferManager.cpp \
src/profiling/CommandHandler.cpp \
src/profiling/CommandHandlerFunctor.cpp \
@@ -193,6 +194,7 @@ LOCAL_SRC_FILES := \
src/profiling/CounterDirectory.cpp \
src/profiling/CounterIdMap.cpp \
src/profiling/DirectoryCaptureCommandHandler.cpp \
+ src/profiling/DeactivateTimelineReportingCommandHandler.cpp \
src/profiling/FileOnlyProfilingConnection.cpp \
src/profiling/Holder.cpp \
src/profiling/LabelsAndEventClasses.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c093344acd..aac18e99cb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -462,6 +462,8 @@ list(APPEND armnn_sources
src/armnn/optimizations/PermuteAndBatchToSpaceAsDepthToSpace.hpp
src/armnn/optimizations/PermuteAsReshape.hpp
src/armnn/optimizations/SquashEqualSiblings.hpp
+ src/profiling/ActivateTimelineReportingCommandHandler.cpp
+ src/profiling/ActivateTimelineReportingCommandHandler.hpp
src/profiling/BufferManager.cpp
src/profiling/BufferManager.hpp
src/profiling/CommandHandler.cpp
@@ -478,6 +480,8 @@ list(APPEND armnn_sources
src/profiling/CounterDirectory.hpp
src/profiling/CounterIdMap.cpp
src/profiling/CounterIdMap.hpp
+ src/profiling/DeactivateTimelineReportingCommandHandler.cpp
+ src/profiling/DeactivateTimelineReportingCommandHandler.hpp
src/profiling/DirectoryCaptureCommandHandler.cpp
src/profiling/DirectoryCaptureCommandHandler.hpp
src/profiling/EncodeVersion.hpp
@@ -490,6 +494,8 @@ list(APPEND armnn_sources
src/profiling/ICounterDirectory.hpp
src/profiling/ICounterRegistry.hpp
src/profiling/ICounterValues.hpp
+ src/profiling/INotifyBackends.hpp
+ src/profiling/IReportStructure.hpp
src/profiling/ISendCounterPacket.hpp
src/profiling/ISendThread.hpp
src/profiling/IPacketBuffer.hpp
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index 8391ed3b15..06d249ea8c 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -66,6 +66,7 @@ public:
, m_FileOnly(false)
, m_CapturePeriod(LOWEST_CAPTURE_PERIOD)
, m_FileFormat("binary")
+ , m_TimelineEnabled(false)
{}
bool m_EnableProfiling;
@@ -74,6 +75,7 @@ public:
bool m_FileOnly;
uint32_t m_CapturePeriod;
std::string m_FileFormat;
+ bool m_TimelineEnabled;
};
ExternalProfilingOptions m_ProfilingOptions;
diff --git a/include/armnn/backends/profiling/IBackendProfilingContext.hpp b/include/armnn/backends/profiling/IBackendProfilingContext.hpp
index 063ebc946d..77959e959b 100644
--- a/include/armnn/backends/profiling/IBackendProfilingContext.hpp
+++ b/include/armnn/backends/profiling/IBackendProfilingContext.hpp
@@ -22,6 +22,7 @@ public:
virtual Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds) = 0;
virtual std::vector<Timestamp> ReportCounterValues() = 0;
virtual bool EnableProfiling(bool flag) = 0;
+ virtual bool EnableTimelineReporting(bool flag) = 0;
};
using IBackendProfilingContextUniquePtr = std::unique_ptr<IBackendProfilingContext>;
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index f3d742c515..9d181e535a 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -263,6 +263,43 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
}
}
+void LoadedNetwork::SendNetworkStructure()
+{
+ Graph& order = m_OptimizedNetwork->GetGraph().TopologicalSort();
+ ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
+
+ std::unique_ptr<TimelineUtilityMethods> timelineUtils =
+ TimelineUtilityMethods::GetTimelineUtils(m_ProfilingService);
+
+ timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID);
+
+ for (auto&& layer : order)
+ {
+ // Add layer to the post-optimisation network structure
+ AddLayerStructure(timelineUtils, *layer, networkGuid);
+ switch (layer->GetType())
+ {
+ case LayerType::Input:
+ case LayerType::Output:
+ {
+ // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput().
+ break;
+ }
+ default:
+ {
+ for (auto& workload : m_WorkloadQueue)
+ {
+ // Add workload to the post-optimisation network structure
+ AddWorkloadStructure(timelineUtils, workload, *layer);
+ }
+ break;
+ }
+ }
+ }
+ // Commit to send the post-optimisation network structure
+ timelineUtils->Commit();
+}
+
TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const
{
for (auto&& inputLayer : m_OptimizedNetwork->GetGraph().GetInputLayers())
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 01e3442508..91379d78ed 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -56,6 +56,8 @@ public:
void RegisterDebugCallback(const DebugCallbackFunction& func);
+ void SendNetworkStructure();
+
private:
void AllocateWorkingMemory();
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 26636a81f7..dfcbf852e0 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -152,11 +152,32 @@ const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
return nullptr;
}
+void Runtime::ReportStructure() // armnn::profiling::IProfilingService& profilingService as param
+{
+ // No-op for the time being, but this may be useful in future to have the profilingService available
+ // if (profilingService.IsProfilingEnabled()){}
+
+ LoadedNetworks::iterator it = m_LoadedNetworks.begin();
+ while (it != m_LoadedNetworks.end())
+ {
+ auto& loadedNetwork = it->second;
+ loadedNetwork->SendNetworkStructure();
+ // Increment the Iterator to point to next entry
+ it++;
+ }
+}
+
Runtime::Runtime(const CreationOptions& options)
- : m_NetworkIdCounter(0)
+ : m_NetworkIdCounter(0),
+ m_ProfilingService(*this)
{
ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n";
+ if ( options.m_ProfilingOptions.m_TimelineEnabled && !options.m_ProfilingOptions.m_EnableProfiling )
+ {
+ throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled");
+ }
+
// pass configuration info to the profiling service
m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions);
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index 477b1169b1..d4b6dcbd63 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -16,13 +16,19 @@
#include <ProfilingService.hpp>
+#include <IProfilingService.hpp>
+#include <IReportStructure.hpp>
+
#include <mutex>
#include <unordered_map>
namespace armnn
{
+using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>;
+using IReportStructure = profiling::IReportStructure;
-class Runtime final : public IRuntime
+class Runtime final : public IRuntime,
+ public IReportStructure
{
public:
/// Loads a complete network into the Runtime.
@@ -79,6 +85,10 @@ public:
~Runtime();
+ //NOTE: we won't need the profiling service reference but it is good to pass the service
+ // in this way to facilitate other implementations down the road
+ virtual void ReportStructure() override;
+
private:
friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
@@ -104,7 +114,9 @@ private:
mutable std::mutex m_Mutex;
- std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
+ /// Map of Loaded Networks with associated GUID as key
+ LoadedNetworks m_LoadedNetworks;
+
std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;
int m_NetworkIdCounter;
diff --git a/src/armnn/test/RuntimeTests.cpp b/src/armnn/test/RuntimeTests.cpp
index c1150231bd..2fc4b50a54 100644
--- a/src/armnn/test/RuntimeTests.cpp
+++ b/src/armnn/test/RuntimeTests.cpp
@@ -371,7 +371,15 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef)
// Create runtime in which the test will run
armnn::IRuntime::CreationOptions options;
options.m_ProfilingOptions.m_EnableProfiling = true;
+ options.m_ProfilingOptions.m_TimelineEnabled = true;
+
armnn::Runtime runtime(options);
+ GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false);
+
+ profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime));
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected);
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck);
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::Active);
// build up the structure of the network
INetworkPtr net(INetwork::Create());
@@ -399,7 +407,6 @@ BOOST_AUTO_TEST_CASE(ProfilingEnableCpuRef)
armnn::NetworkId netId;
BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success);
- profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime));
profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager();
auto readableBuffer = bufferManager.GetReadableBuffer();
diff --git a/src/backends/backendsCommon/test/MockBackend.cpp b/src/backends/backendsCommon/test/MockBackend.cpp
index 8d40117741..116bf77c63 100644
--- a/src/backends/backendsCommon/test/MockBackend.cpp
+++ b/src/backends/backendsCommon/test/MockBackend.cpp
@@ -5,7 +5,6 @@
#include "MockBackend.hpp"
#include "MockBackendId.hpp"
-#include "armnn/backends/profiling/IBackendProfilingContext.hpp"
#include <armnn/BackendRegistry.hpp>
diff --git a/src/backends/backendsCommon/test/MockBackend.hpp b/src/backends/backendsCommon/test/MockBackend.hpp
index 6e415b9b52..e1570ff920 100644
--- a/src/backends/backendsCommon/test/MockBackend.hpp
+++ b/src/backends/backendsCommon/test/MockBackend.hpp
@@ -32,6 +32,7 @@ public:
MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
: m_BackendProfiling(std::move(backendProfiling))
, m_CapturePeriod(0)
+ , m_IsTimelineEnabled(true)
{}
~MockBackendProfilingContext() = default;
@@ -93,10 +94,22 @@ public:
return true;
}
+ bool EnableTimelineReporting(bool isEnabled)
+ {
+ m_IsTimelineEnabled = isEnabled;
+ return isEnabled;
+ }
+
+ bool TimelineReportingEnabled()
+ {
+ return m_IsTimelineEnabled;
+ }
+
private:
IBackendInternal::IBackendProfilingPtr m_BackendProfiling;
uint32_t m_CapturePeriod;
std::vector<uint16_t> m_ActiveCounters;
+ bool m_IsTimelineEnabled;
};
class MockBackendProfilingService
diff --git a/src/profiling/ActivateTimelineReportingCommandHandler.cpp b/src/profiling/ActivateTimelineReportingCommandHandler.cpp
new file mode 100644
index 0000000000..d762efc277
--- /dev/null
+++ b/src/profiling/ActivateTimelineReportingCommandHandler.cpp
@@ -0,0 +1,63 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ActivateTimelineReportingCommandHandler.hpp"
+#include "TimelineUtilityMethods.hpp"
+
+#include <armnn/Exceptions.hpp>
+#include <boost/format.hpp>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+void ActivateTimelineReportingCommandHandler::operator()(const Packet& packet)
+{
+ ProfilingState currentState = m_StateMachine.GetCurrentState();
+
+ if (!m_ReportStructure.has_value())
+ {
+ throw armnn::Exception(std::string("Profiling Service constructor must be initialised with an "
+ "IReportStructure argument in order to run timeline reporting"));
+ }
+
+ switch ( currentState )
+ {
+ case ProfilingState::Uninitialised:
+ case ProfilingState::NotConnected:
+ case ProfilingState::WaitingForAck:
+ throw RuntimeException(boost::str(
+ boost::format("Activate Timeline Reporting Command Handler invoked while in a wrong state: %1%")
+ % GetProfilingStateName(currentState)));
+ case ProfilingState::Active:
+ if ( !( packet.GetPacketFamily() == 0u && packet.GetPacketId() == 6u ))
+ {
+ throw armnn::Exception(std::string("Expected Packet family = 0, id = 6 but received family =")
+ + std::to_string(packet.GetPacketFamily())
+ + " id = " + std::to_string(packet.GetPacketId()));
+ }
+
+ m_SendTimelinePacket.SendTimelineMessageDirectoryPackage();
+
+ TimelineUtilityMethods::SendWellKnownLabelsAndEventClasses(m_SendTimelinePacket);
+
+ m_TimelineReporting = true;
+
+ m_ReportStructure.value().ReportStructure();
+
+ m_BackendNotifier.NotifyBackendsForTimelineReporting();
+
+ break;
+ default:
+ throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+ % static_cast<int>(currentState)));
+ }
+}
+
+} // namespace profiling
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/profiling/ActivateTimelineReportingCommandHandler.hpp b/src/profiling/ActivateTimelineReportingCommandHandler.hpp
new file mode 100644
index 0000000000..ac11b46379
--- /dev/null
+++ b/src/profiling/ActivateTimelineReportingCommandHandler.hpp
@@ -0,0 +1,54 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "CommandHandlerFunctor.hpp"
+#include "ProfilingStateMachine.hpp"
+#include "Packet.hpp"
+#include "SendTimelinePacket.hpp"
+#include "IReportStructure.hpp"
+#include "armnn/Optional.hpp"
+#include "INotifyBackends.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class ActivateTimelineReportingCommandHandler : public CommandHandlerFunctor
+{
+public:
+ ActivateTimelineReportingCommandHandler(uint32_t familyId,
+ uint32_t packetId,
+ uint32_t version,
+ SendTimelinePacket& sendTimelinePacket,
+ ProfilingStateMachine& profilingStateMachine,
+ Optional<IReportStructure&> reportStructure,
+ std::atomic<bool>& timelineReporting,
+ INotifyBackends& notifyBackends)
+ : CommandHandlerFunctor(familyId, packetId, version),
+ m_SendTimelinePacket(sendTimelinePacket),
+ m_StateMachine(profilingStateMachine),
+ m_TimelineReporting(timelineReporting),
+ m_BackendNotifier(notifyBackends),
+ m_ReportStructure(reportStructure)
+ {}
+
+ void operator()(const Packet& packet) override;
+
+private:
+ SendTimelinePacket& m_SendTimelinePacket;
+ ProfilingStateMachine& m_StateMachine;
+ std::atomic<bool>& m_TimelineReporting;
+ INotifyBackends& m_BackendNotifier;
+
+ Optional<IReportStructure&> m_ReportStructure;
+};
+
+} // namespace profiling
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/profiling/CommandHandler.hpp b/src/profiling/CommandHandler.hpp
index 0cc23429cd..4bf820c5db 100644
--- a/src/profiling/CommandHandler.hpp
+++ b/src/profiling/CommandHandler.hpp
@@ -22,16 +22,16 @@ class CommandHandler
{
public:
CommandHandler(uint32_t timeout,
- bool stopAfterTimeout,
- CommandHandlerRegistry& commandHandlerRegistry,
- PacketVersionResolver& packetVersionResolver)
- : m_Timeout(timeout)
- , m_StopAfterTimeout(stopAfterTimeout)
- , m_IsRunning(false)
- , m_KeepRunning(false)
- , m_CommandThread()
- , m_CommandHandlerRegistry(commandHandlerRegistry)
- , m_PacketVersionResolver(packetVersionResolver)
+ bool stopAfterTimeout,
+ CommandHandlerRegistry& commandHandlerRegistry,
+ PacketVersionResolver& packetVersionResolver)
+ : m_Timeout(timeout),
+ m_StopAfterTimeout(stopAfterTimeout),
+ m_IsRunning(false),
+ m_KeepRunning(false),
+ m_CommandThread(),
+ m_CommandHandlerRegistry(commandHandlerRegistry),
+ m_PacketVersionResolver(packetVersionResolver)
{}
~CommandHandler() { Stop(); }
@@ -46,13 +46,13 @@ private:
void HandleCommands(IProfilingConnection& profilingConnection);
std::atomic<uint32_t> m_Timeout;
- std::atomic<bool> m_StopAfterTimeout;
- std::atomic<bool> m_IsRunning;
- std::atomic<bool> m_KeepRunning;
- std::thread m_CommandThread;
+ std::atomic<bool> m_StopAfterTimeout;
+ std::atomic<bool> m_IsRunning;
+ std::atomic<bool> m_KeepRunning;
+ std::thread m_CommandThread;
CommandHandlerRegistry& m_CommandHandlerRegistry;
- PacketVersionResolver& m_PacketVersionResolver;
+ PacketVersionResolver& m_PacketVersionResolver;
};
} // namespace profiling
diff --git a/src/profiling/DeactivateTimelineReportingCommandHandler.cpp b/src/profiling/DeactivateTimelineReportingCommandHandler.cpp
new file mode 100644
index 0000000000..dbfb053b3d
--- /dev/null
+++ b/src/profiling/DeactivateTimelineReportingCommandHandler.cpp
@@ -0,0 +1,53 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "DeactivateTimelineReportingCommandHandler.hpp"
+
+#include <armnn/Exceptions.hpp>
+#include <boost/format.hpp>
+
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+void DeactivateTimelineReportingCommandHandler::operator()(const Packet& packet)
+{
+ ProfilingState currentState = m_StateMachine.GetCurrentState();
+
+ switch ( currentState )
+ {
+ case ProfilingState::Uninitialised:
+ case ProfilingState::NotConnected:
+ case ProfilingState::WaitingForAck:
+ throw RuntimeException(boost::str(
+ boost::format("Deactivate Timeline Reporting Command Handler invoked while in a wrong state: %1%")
+ % GetProfilingStateName(currentState)));
+ case ProfilingState::Active:
+ if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 7u))
+ {
+ throw armnn::Exception(std::string("Expected Packet family = 0, id = 7 but received family =")
+ + std::to_string(packet.GetPacketFamily())
+ +" id = " + std::to_string(packet.GetPacketId()));
+ }
+
+ m_TimelineReporting.store(false);
+
+ // Notify Backends
+ m_BackendNotifier.NotifyBackendsForTimelineReporting();
+
+ break;
+ default:
+ throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+ % static_cast<int>(currentState)));
+ }
+}
+
+} // namespace profiling
+
+} // namespace armnn
+
diff --git a/src/profiling/DeactivateTimelineReportingCommandHandler.hpp b/src/profiling/DeactivateTimelineReportingCommandHandler.hpp
new file mode 100644
index 0000000000..e06bae836f
--- /dev/null
+++ b/src/profiling/DeactivateTimelineReportingCommandHandler.hpp
@@ -0,0 +1,45 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "CommandHandlerFunctor.hpp"
+#include "Packet.hpp"
+#include "ProfilingStateMachine.hpp"
+#include "INotifyBackends.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class DeactivateTimelineReportingCommandHandler : public CommandHandlerFunctor
+{
+
+public:
+ DeactivateTimelineReportingCommandHandler(uint32_t familyId,
+ uint32_t packetId,
+ uint32_t version,
+ std::atomic<bool>& timelineReporting,
+ ProfilingStateMachine& profilingStateMachine,
+ INotifyBackends& notifyBackends)
+ : CommandHandlerFunctor(familyId, packetId, version)
+ , m_TimelineReporting(timelineReporting)
+ , m_StateMachine(profilingStateMachine)
+ , m_BackendNotifier(notifyBackends)
+ {}
+
+ void operator()(const Packet& packet) override;
+
+private:
+ std::atomic<bool>& m_TimelineReporting;
+ ProfilingStateMachine& m_StateMachine;
+ INotifyBackends& m_BackendNotifier;
+};
+
+} // namespace profiling
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/profiling/DirectoryCaptureCommandHandler.cpp b/src/profiling/DirectoryCaptureCommandHandler.cpp
index 65cac848ae..93cdde736e 100644
--- a/src/profiling/DirectoryCaptureCommandHandler.cpp
+++ b/src/profiling/DirectoryCaptureCommandHandler.cpp
@@ -281,7 +281,8 @@ std::vector<CounterDirectoryEventRecord> DirectoryCaptureCommandHandler::ReadEve
eventRecords[i].m_CounterDescription = GetStringNameFromBuffer(data, eventRecordOffset + descriptionOffset);
- eventRecords[i].m_CounterUnits = GetStringNameFromBuffer(data, eventRecordOffset + unitsOffset);
+ eventRecords[i].m_CounterUnits = unitsOffset == 0 ? Optional<std::string>() :
+ GetStringNameFromBuffer(data, eventRecordOffset + unitsOffset);
}
return eventRecords;
diff --git a/src/profiling/DirectoryCaptureCommandHandler.hpp b/src/profiling/DirectoryCaptureCommandHandler.hpp
index 03bbb1eb09..6b25714168 100644
--- a/src/profiling/DirectoryCaptureCommandHandler.hpp
+++ b/src/profiling/DirectoryCaptureCommandHandler.hpp
@@ -25,7 +25,7 @@ struct CounterDirectoryEventRecord
std::string m_CounterName;
uint16_t m_CounterSetUid;
uint16_t m_CounterUid;
- std::string m_CounterUnits;
+ Optional<std::string> m_CounterUnits;
uint16_t m_DeviceUid;
uint16_t m_MaxCounterUid;
};
diff --git a/src/profiling/INotifyBackends.hpp b/src/profiling/INotifyBackends.hpp
new file mode 100644
index 0000000000..217ebdef1c
--- /dev/null
+++ b/src/profiling/INotifyBackends.hpp
@@ -0,0 +1,24 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class INotifyBackends
+{
+public:
+ virtual ~INotifyBackends() {}
+ virtual void NotifyBackendsForTimelineReporting() = 0;
+};
+
+} // namespace profiling
+
+} // namespace armnn
+
diff --git a/src/profiling/IReportStructure.hpp b/src/profiling/IReportStructure.hpp
new file mode 100644
index 0000000000..1ae049733f
--- /dev/null
+++ b/src/profiling/IReportStructure.hpp
@@ -0,0 +1,24 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class IReportStructure
+{
+public:
+ virtual ~IReportStructure() {}
+ virtual void ReportStructure() = 0;
+};
+
+} // namespace profiling
+
+} // namespace armnn
+
diff --git a/src/profiling/PacketVersionResolver.cpp b/src/profiling/PacketVersionResolver.cpp
index 2c75067487..4178ace166 100644
--- a/src/profiling/PacketVersionResolver.cpp
+++ b/src/profiling/PacketVersionResolver.cpp
@@ -54,8 +54,17 @@ bool PacketKey::operator!=(const PacketKey& rhs) const
Version PacketVersionResolver::ResolvePacketVersion(uint32_t familyId, uint32_t packetId) const
{
- IgnoreUnused(familyId, packetId);
- // NOTE: For now every packet specification is at version 1.0.0
+ const PacketKey packetKey(familyId, packetId);
+
+ if( packetKey == ActivateTimeLinePacket )
+ {
+ return Version(1, 1, 0);
+ }
+ if( packetKey == DectivateTimeLinePacket )
+ {
+ return Version(1, 1, 0);
+ }
+
return Version(1, 0, 0);
}
diff --git a/src/profiling/PacketVersionResolver.hpp b/src/profiling/PacketVersionResolver.hpp
index e959ed548e..6222eb02e8 100644
--- a/src/profiling/PacketVersionResolver.hpp
+++ b/src/profiling/PacketVersionResolver.hpp
@@ -33,6 +33,9 @@ private:
uint32_t m_PacketId;
};
+static const PacketKey ActivateTimeLinePacket(0 , 6);
+static const PacketKey DectivateTimeLinePacket(0 , 7);
+
class PacketVersionResolver final
{
public:
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index e42ef134dc..3a8f3f83a3 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -34,6 +34,7 @@ void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOpti
{
// Update the profiling options
m_Options = options;
+ m_TimelineReporting = options.m_TimelineEnabled;
// Check if the profiling service needs to be reset
if (resetProfilingService)
@@ -431,7 +432,7 @@ void ProfilingService::Reset()
// ...finally reset the profiling state machine
m_StateMachine.Reset();
m_BackendProfilingContexts.clear();
- m_MaxGlobalCounterId = armnn::profiling::INFERENCES_RUN;
+ m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER;
}
void ProfilingService::Stop()
@@ -463,11 +464,22 @@ inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
}
}
+void ProfilingService::NotifyBackendsForTimelineReporting()
+{
+ BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin();
+ while (it != m_BackendProfilingContexts.end())
+ {
+ auto& backendProfilingContext = it->second;
+ backendProfilingContext->EnableTimelineReporting(m_TimelineReporting);
+ // Increment the Iterator to point to next entry
+ it++;
+ }
+}
+
ProfilingService::~ProfilingService()
{
Stop();
}
-
} // namespace profiling
} // namespace armnn
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index f6e409daeb..df7bd8f857 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -5,14 +5,17 @@
#pragma once
+#include "ActivateTimelineReportingCommandHandler.hpp"
#include "BufferManager.hpp"
#include "CommandHandler.hpp"
#include "ConnectionAcknowledgedCommandHandler.hpp"
#include "CounterDirectory.hpp"
#include "CounterIdMap.hpp"
+#include "DeactivateTimelineReportingCommandHandler.hpp"
#include "ICounterRegistry.hpp"
#include "ICounterValues.hpp"
#include "IProfilingService.hpp"
+#include "IReportStructure.hpp"
#include "PeriodicCounterCapture.hpp"
#include "PeriodicCounterSelectionCommandHandler.hpp"
#include "PerJobCounterSelectionCommandHandler.hpp"
@@ -24,6 +27,7 @@
#include "SendThread.hpp"
#include "SendTimelinePacket.hpp"
#include "TimelinePacketWriterFactory.hpp"
+#include "INotifyBackends.hpp"
#include <armnn/backends/profiling/IBackendProfilingContext.hpp>
namespace armnn
@@ -32,14 +36,14 @@ namespace armnn
namespace profiling
{
// Static constants describing ArmNN's counter UID's
-static const uint16_t NETWORK_LOADS = 0;
-static const uint16_t NETWORK_UNLOADS = 1;
-static const uint16_t REGISTERED_BACKENDS = 2;
-static const uint16_t UNREGISTERED_BACKENDS = 3;
-static const uint16_t INFERENCES_RUN = 4;
-static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN;
-
-class ProfilingService : public IReadWriteCounterValues, public IProfilingService
+static const uint16_t NETWORK_LOADS = 0;
+static const uint16_t NETWORK_UNLOADS = 1;
+static const uint16_t REGISTERED_BACKENDS = 2;
+static const uint16_t UNREGISTERED_BACKENDS = 3;
+static const uint16_t INFERENCES_RUN = 4;
+static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN;
+
+class ProfilingService : public IReadWriteCounterValues, public IProfilingService, public INotifyBackends
{
public:
using ExternalProfilingOptions = IRuntime::CreationOptions::ExternalProfilingOptions;
@@ -47,10 +51,12 @@ public:
using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>;
using CounterIndices = std::vector<std::atomic<uint32_t>*>;
using CounterValues = std::list<std::atomic<uint32_t>>;
+ using BackendProfilingContext = std::unordered_map<BackendId,
+ std::shared_ptr<armnn::profiling::IBackendProfilingContext>>;
- // Default constructor/destructor kept protected for testing
- ProfilingService()
+ ProfilingService(Optional<IReportStructure&> reportStructure = EmptyOptional())
: m_Options()
+ , m_TimelineReporting(false)
, m_CounterDirectory()
, m_ProfilingConnectionFactory(new ProfilingConnectionFactory())
, m_ProfilingConnection()
@@ -97,6 +103,22 @@ public:
5,
m_PacketVersionResolver.ResolvePacketVersion(0, 5).GetEncodedValue(),
m_StateMachine)
+ , m_ActivateTimelineReportingCommandHandler(0,
+ 6,
+ m_PacketVersionResolver.ResolvePacketVersion(0, 6)
+ .GetEncodedValue(),
+ m_SendTimelinePacket,
+ m_StateMachine,
+ reportStructure,
+ m_TimelineReporting,
+ *this)
+ , m_DeactivateTimelineReportingCommandHandler(0,
+ 7,
+ m_PacketVersionResolver.ResolvePacketVersion(0, 7)
+ .GetEncodedValue(),
+ m_TimelineReporting,
+ m_StateMachine,
+ *this)
, m_TimelinePacketWriterFactory(m_BufferManager)
, m_MaxGlobalCounterId(armnn::profiling::INFERENCES_RUN)
{
@@ -111,6 +133,10 @@ public:
// Register the "Per-Job Counter Selection" command handler
m_CommandHandlerRegistry.RegisterFunctor(&m_PerJobCounterSelectionCommandHandler);
+
+ m_CommandHandlerRegistry.RegisterFunctor(&m_ActivateTimelineReportingCommandHandler);
+
+ m_CommandHandlerRegistry.RegisterFunctor(&m_DeactivateTimelineReportingCommandHandler);
}
~ProfilingService();
@@ -131,6 +157,9 @@ public:
void AddBackendProfilingContext(const BackendId backendId,
std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext);
+ // Enable the recording of timeline events and entities
+ void NotifyBackendsForTimelineReporting() override;
+
const ICounterDirectory& GetCounterDirectory() const;
ICounterRegistry& GetCounterRegistry();
ProfilingState GetCurrentState() const;
@@ -168,13 +197,15 @@ public:
return m_SendCounterPacket;
}
- /// Check if the profiling is enabled
- bool IsEnabled() { return m_Options.m_EnableProfiling; }
-
static ProfilingDynamicGuid GetNextGuid();
static ProfilingStaticGuid GetStaticId(const std::string& str);
+ bool IsTimelineReportingEnabled()
+ {
+ return m_TimelineReporting;
+ }
+
private:
//Copy/move constructors/destructors and copy/move assignment operators are deleted
ProfilingService(const ProfilingService&) = delete;
@@ -192,32 +223,37 @@ private:
void CheckCounterUid(uint16_t counterUid) const;
// Profiling service components
- ExternalProfilingOptions m_Options;
- CounterDirectory m_CounterDirectory;
- CounterIdMap m_CounterIdMap;
+ ExternalProfilingOptions m_Options;
+ std::atomic<bool> m_TimelineReporting;
+ CounterDirectory m_CounterDirectory;
+ CounterIdMap m_CounterIdMap;
IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory;
- IProfilingConnectionPtr m_ProfilingConnection;
- ProfilingStateMachine m_StateMachine;
- CounterIndices m_CounterIndex;
- CounterValues m_CounterValues;
- CommandHandlerRegistry m_CommandHandlerRegistry;
- PacketVersionResolver m_PacketVersionResolver;
- CommandHandler m_CommandHandler;
- BufferManager m_BufferManager;
- SendCounterPacket m_SendCounterPacket;
- SendThread m_SendThread;
- SendTimelinePacket m_SendTimelinePacket;
+ IProfilingConnectionPtr m_ProfilingConnection;
+ ProfilingStateMachine m_StateMachine;
+ CounterIndices m_CounterIndex;
+ CounterValues m_CounterValues;
+ CommandHandlerRegistry m_CommandHandlerRegistry;
+ PacketVersionResolver m_PacketVersionResolver;
+ CommandHandler m_CommandHandler;
+ BufferManager m_BufferManager;
+ SendCounterPacket m_SendCounterPacket;
+ SendThread m_SendThread;
+ SendTimelinePacket m_SendTimelinePacket;
+
Holder m_Holder;
+
PeriodicCounterCapture m_PeriodicCounterCapture;
- ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
- RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
- PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler;
- PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler;
+
+ ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
+ RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
+ PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler;
+ PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler;
+ ActivateTimelineReportingCommandHandler m_ActivateTimelineReportingCommandHandler;
+ DeactivateTimelineReportingCommandHandler m_DeactivateTimelineReportingCommandHandler;
TimelinePacketWriterFactory m_TimelinePacketWriterFactory;
- std::unordered_map<BackendId,
- std::shared_ptr<armnn::profiling::IBackendProfilingContext>> m_BackendProfilingContexts;
- uint16_t m_MaxGlobalCounterId;
+ BackendProfilingContext m_BackendProfilingContexts;
+ uint16_t m_MaxGlobalCounterId;
static ProfilingGuidGenerator m_GuidGenerator;
diff --git a/src/profiling/ProfilingUtils.cpp b/src/profiling/ProfilingUtils.cpp
index 002eeb9616..e419769600 100644
--- a/src/profiling/ProfilingUtils.cpp
+++ b/src/profiling/ProfilingUtils.cpp
@@ -96,17 +96,8 @@ void WriteBytes(const IPacketBufferPtr& packetBuffer, unsigned int offset, cons
uint32_t ConstructHeader(uint32_t packetFamily,
uint32_t packetId)
{
- return ((packetFamily & 0x3F) << 26)|
- ((packetId & 0x3FF) << 16);
-}
-
-uint32_t ConstructHeader(uint32_t packetFamily,
- uint32_t packetClass,
- uint32_t packetType)
-{
- return ((packetFamily & 0x3F) << 26)|
- ((packetClass & 0x3FF) << 19)|
- ((packetType & 0x3FFF) << 16);
+ return (( packetFamily & 0x0000003F ) << 26 )|
+ (( packetId & 0x000003FF ) << 16 );
}
void WriteUint64(const std::unique_ptr<IPacketBuffer>& packetBuffer, unsigned int offset, uint64_t value)
diff --git a/src/profiling/ProfilingUtils.hpp b/src/profiling/ProfilingUtils.hpp
index 37ab88cb6f..5888ef0b8c 100644
--- a/src/profiling/ProfilingUtils.hpp
+++ b/src/profiling/ProfilingUtils.hpp
@@ -146,8 +146,6 @@ void WriteBytes(const IPacketBuffer& packetBuffer, unsigned int offset, const vo
uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetId);
-uint32_t ConstructHeader(uint32_t packetFamily, uint32_t packetClass, uint32_t packetType);
-
void WriteUint64(const IPacketBufferPtr& packetBuffer, unsigned int offset, uint64_t value);
void WriteUint32(const IPacketBufferPtr& packetBuffer, unsigned int offset, uint32_t value);
diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp
index 942ccc7b59..ae4bab91e7 100644
--- a/src/profiling/SendCounterPacket.cpp
+++ b/src/profiling/SendCounterPacket.cpp
@@ -181,10 +181,34 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category,
BOOST_ASSERT(category);
const std::string& categoryName = category->m_Name;
- const std::vector<uint16_t> categoryCounters = category->m_Counters;
-
BOOST_ASSERT(!categoryName.empty());
+ // Remove any duplicate counters
+ std::vector<uint16_t> categoryCounters;
+ for (size_t counterIndex = 0; counterIndex < category->m_Counters.size(); ++counterIndex)
+ {
+ uint16_t counterUid = category->m_Counters.at(counterIndex);
+ auto it = counters.find(counterUid);
+ if (it == counters.end())
+ {
+ errorMessage = boost::str(boost::format("Counter (%1%) not found in category (%2%)")
+ % counterUid % category->m_Name );
+ return false;
+ }
+
+ const CounterPtr& counter = it->second;
+
+ if (counterUid == counter->m_MaxCounterUid)
+ {
+ categoryCounters.emplace_back(counterUid);
+ }
+ }
+ if (categoryCounters.empty())
+ {
+ errorMessage = boost::str(boost::format("No valid counters found in category (%1%)")% categoryName);
+ return false;
+ }
+
// Utils
size_t uint32_t_size = sizeof(uint32_t);
@@ -203,7 +227,7 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category,
std::vector<uint32_t> categoryNameBuffer;
if (!StringToSwTraceString<SwTraceNameCharPolicy>(categoryName, categoryNameBuffer))
{
- errorMessage = boost::str(boost::format("Cannot convert the name of category \"%1%\" to an SWTrace namestring")
+ errorMessage = boost::str(boost::format("Cannot convert the name of category (%1%) to an SWTrace namestring")
% categoryName);
return false;
}
@@ -221,7 +245,6 @@ bool SendCounterPacket::CreateCategoryRecord(const CategoryPtr& category,
{
uint16_t counterUid = categoryCounters.at(counterIndex);
auto it = counters.find(counterUid);
- BOOST_ASSERT(it != counters.end());
const CounterPtr& counter = it->second;
EventRecord& eventRecord = eventRecords.at(eventRecordIndex);
@@ -299,7 +322,7 @@ bool SendCounterPacket::CreateDeviceRecord(const DevicePtr& device,
std::vector<uint32_t> deviceNameBuffer;
if (!StringToSwTraceString<SwTraceCharPolicy>(deviceName, deviceNameBuffer))
{
- errorMessage = boost::str(boost::format("Cannot convert the name of device %1% (\"%2%\") to an SWTrace string")
+ errorMessage = boost::str(boost::format("Cannot convert the name of device %1% (%2%) to an SWTrace string")
% deviceUid
% deviceName);
return false;
@@ -349,7 +372,7 @@ bool SendCounterPacket::CreateCounterSetRecord(const CounterSetPtr& counterSet,
std::vector<uint32_t> counterSetNameBuffer;
if (!StringToSwTraceString<SwTraceNameCharPolicy>(counterSet->m_Name, counterSetNameBuffer))
{
- errorMessage = boost::str(boost::format("Cannot convert the name of counter set %1% (\"%2%\") to "
+ errorMessage = boost::str(boost::format("Cannot convert the name of counter set %1% (%2%) to "
"an SWTrace namestring")
% counterSetUid
% counterSetName);
@@ -441,7 +464,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter,
std::vector<uint32_t> counterNameBuffer;
if (!StringToSwTraceString<SwTraceCharPolicy>(counterName, counterNameBuffer))
{
- errorMessage = boost::str(boost::format("Cannot convert the name of counter %1% (name: \"%2%\") "
+ errorMessage = boost::str(boost::format("Cannot convert the name of counter %1% (name: %2%) "
"to an SWTrace string")
% counterUid
% counterName);
@@ -457,7 +480,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter,
std::vector<uint32_t> counterDescriptionBuffer;
if (!StringToSwTraceString<SwTraceCharPolicy>(counterDescription, counterDescriptionBuffer))
{
- errorMessage = boost::str(boost::format("Cannot convert the description of counter %1% (description: \"%2%\") "
+ errorMessage = boost::str(boost::format("Cannot convert the description of counter %1% (description: %2%) "
"to an SWTrace string")
% counterUid
% counterName);
@@ -481,7 +504,7 @@ bool SendCounterPacket::CreateEventRecord(const CounterPtr& counter,
// Convert the counter units into a SWTrace namestring
if (!StringToSwTraceString<SwTraceNameCharPolicy>(counterUnits, counterUnitsBuffer))
{
- errorMessage = boost::str(boost::format("Cannot convert the units of counter %1% (units: \"%2%\") "
+ errorMessage = boost::str(boost::format("Cannot convert the units of counter %1% (units: %2%) "
"to an SWTrace string")
% counterUid
% counterName);
diff --git a/src/profiling/TimelineUtilityMethods.cpp b/src/profiling/TimelineUtilityMethods.cpp
index 8a39ea7c4a..de30b4d4ef 100644
--- a/src/profiling/TimelineUtilityMethods.cpp
+++ b/src/profiling/TimelineUtilityMethods.cpp
@@ -14,7 +14,7 @@ namespace profiling
std::unique_ptr<TimelineUtilityMethods> TimelineUtilityMethods::GetTimelineUtils(ProfilingService& profilingService)
{
- if (profilingService.IsProfilingEnabled())
+ if (profilingService.GetCurrentState() == ProfilingState::Active && profilingService.IsTimelineReportingEnabled())
{
std::unique_ptr<ISendTimelinePacket> sendTimelinepacket = profilingService.GetSendTimelinePacket();
return std::make_unique<TimelineUtilityMethods>(sendTimelinepacket);
diff --git a/src/profiling/test/ProfilingMocks.hpp b/src/profiling/test/ProfilingMocks.hpp
index eeb641e878..ada55d8dff 100644
--- a/src/profiling/test/ProfilingMocks.hpp
+++ b/src/profiling/test/ProfilingMocks.hpp
@@ -51,6 +51,8 @@ public:
PerJobCounterSelection,
TimelineMessageDirectory,
PeriodicCounterCapture,
+ ActivateTimelineReporting,
+ DeactivateTimelineReporting,
Unknown
};
@@ -85,7 +87,7 @@ public:
switch (packetFamily)
{
case 0:
- packetType = packetId < 6 ? PacketType(packetId) : PacketType::Unknown;
+ packetType = packetId < 8 ? PacketType(packetId) : PacketType::Unknown;
break;
case 1:
packetType = packetId == 0 ? PacketType::TimelineMessageDirectory : PacketType::Unknown;
@@ -628,7 +630,8 @@ public:
const CaptureData& captureData) :
m_SendCounterPacket(mockBufferManager),
m_IsProfilingEnabled(isProfilingEnabled),
- m_CaptureData(captureData) {}
+ m_CaptureData(captureData)
+ {}
/// Return the next random Guid in the sequence
ProfilingDynamicGuid NextGuid() override
@@ -682,10 +685,10 @@ public:
private:
ProfilingGuidGenerator m_GuidGenerator;
- CounterIdMap m_CounterMapping;
- SendCounterPacket m_SendCounterPacket;
- bool m_IsProfilingEnabled;
- CaptureData m_CaptureData;
+ CounterIdMap m_CounterMapping;
+ SendCounterPacket m_SendCounterPacket;
+ bool m_IsProfilingEnabled;
+ CaptureData m_CaptureData;
};
} // namespace profiling
diff --git a/src/profiling/test/ProfilingTestUtils.cpp b/src/profiling/test/ProfilingTestUtils.cpp
index 244051c785..8de69f14ec 100644
--- a/src/profiling/test/ProfilingTestUtils.cpp
+++ b/src/profiling/test/ProfilingTestUtils.cpp
@@ -297,7 +297,14 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId)
// Create runtime in which test will run
armnn::IRuntime::CreationOptions options;
options.m_ProfilingOptions.m_EnableProfiling = true;
+ options.m_ProfilingOptions.m_TimelineEnabled = true;
armnn::Runtime runtime(options);
+ GetProfilingService(&runtime).ResetExternalProfilingOptions(options.m_ProfilingOptions, false);
+
+ profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime));
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected);
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck);
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::Active);
// build up the structure of the network
INetworkPtr net(INetwork::Create());
@@ -363,7 +370,6 @@ void VerifyPostOptimisationStructureTestImpl(armnn::BackendId backendId)
armnn::NetworkId netId;
BOOST_TEST(runtime.LoadNetwork(netId, std::move(optNet)) == Status::Success);
- profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime));
profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager();
auto readableBuffer = bufferManager.GetReadableBuffer();
diff --git a/src/profiling/test/ProfilingTestUtils.hpp b/src/profiling/test/ProfilingTestUtils.hpp
index 459d62435b..816ffd3dc6 100644
--- a/src/profiling/test/ProfilingTestUtils.hpp
+++ b/src/profiling/test/ProfilingTestUtils.hpp
@@ -6,6 +6,7 @@
#pragma once
#include "ProfilingUtils.hpp"
+#include "Runtime.hpp"
#include <armnn/BackendId.hpp>
#include <armnn/Optional.hpp>
@@ -68,6 +69,11 @@ public:
return GetBufferManager(m_ProfilingService);
}
armnn::profiling::ProfilingService& m_ProfilingService;
+
+ void ForceTransitionToState(ProfilingState newState)
+ {
+ TransitionToState(m_ProfilingService, newState);
+ }
};
} // namespace profiling
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 74b93d7c4f..f252579022 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -4,6 +4,7 @@
//
#include "ProfilingTests.hpp"
+#include "ProfilingTestUtils.hpp"
#include <backends/BackendProfiling.hpp>
#include <CommandHandler.hpp>
@@ -1823,6 +1824,132 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
BOOST_TEST(period == armnn::LOWEST_CAPTURE_PERIOD); // capture period
}
+BOOST_AUTO_TEST_CASE(CheckTimelineActivationAndDeactivation)
+{
+ class TestReportStructure : public IReportStructure
+ {
+ public:
+ virtual void ReportStructure() override
+ {
+ m_ReportStructureCalled = true;
+ }
+
+ bool m_ReportStructureCalled = false;
+ };
+
+ class TestNotifyBackends : public INotifyBackends
+ {
+ public:
+ TestNotifyBackends() : m_timelineReporting(false) {}
+ virtual void NotifyBackendsForTimelineReporting() override
+ {
+ m_TestNotifyBackendsCalled = m_timelineReporting.load();
+ }
+
+ bool m_TestNotifyBackendsCalled = false;
+ std::atomic<bool> m_timelineReporting;
+ };
+
+ PacketVersionResolver packetVersionResolver;
+
+ BufferManager bufferManager(512);
+ SendTimelinePacket sendTimelinePacket(bufferManager);
+ ProfilingStateMachine stateMachine;
+ TestReportStructure testReportStructure;
+ TestNotifyBackends testNotifyBackends;
+
+ profiling::ActivateTimelineReportingCommandHandler activateTimelineReportingCommandHandler(0,
+ 6,
+ packetVersionResolver.ResolvePacketVersion(0, 6)
+ .GetEncodedValue(),
+ sendTimelinePacket,
+ stateMachine,
+ testReportStructure,
+ testNotifyBackends.m_timelineReporting,
+ testNotifyBackends);
+
+ // Write an "ActivateTimelineReporting" packet into the mock profiling connection, to simulate an input from an
+ // external profiling service
+ const uint32_t packetFamily1 = 0;
+ const uint32_t packetId1 = 6;
+ uint32_t packetHeader1 = ConstructHeader(packetFamily1, packetId1);
+
+ // Create the ActivateTimelineReportingPacket
+ Packet ActivateTimelineReportingPacket(packetHeader1); // Length == 0
+
+ BOOST_CHECK_THROW(
+ activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception);
+
+ stateMachine.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK_THROW(
+ activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception);
+
+ stateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ BOOST_CHECK_THROW(
+ activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket), armnn::Exception);
+
+ stateMachine.TransitionToState(ProfilingState::Active);
+ activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket);
+
+ BOOST_CHECK(testReportStructure.m_ReportStructureCalled);
+ BOOST_CHECK(testNotifyBackends.m_TestNotifyBackendsCalled);
+ BOOST_CHECK(testNotifyBackends.m_timelineReporting.load());
+
+ DeactivateTimelineReportingCommandHandler deactivateTimelineReportingCommandHandler(0,
+ 7,
+ packetVersionResolver.ResolvePacketVersion(0, 7).GetEncodedValue(),
+ testNotifyBackends.m_timelineReporting,
+ stateMachine,
+ testNotifyBackends);
+
+ const uint32_t packetFamily2 = 0;
+ const uint32_t packetId2 = 7;
+ uint32_t packetHeader2 = ConstructHeader(packetFamily2, packetId2);
+
+ // Create the DeactivateTimelineReportingPacket
+ Packet deactivateTimelineReportingPacket(packetHeader2); // Length == 0
+
+ stateMachine.Reset();
+ BOOST_CHECK_THROW(
+ deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception);
+
+ stateMachine.TransitionToState(ProfilingState::NotConnected);
+ BOOST_CHECK_THROW(
+ deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception);
+
+ stateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ BOOST_CHECK_THROW(
+ deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket), armnn::Exception);
+
+ stateMachine.TransitionToState(ProfilingState::Active);
+ deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket);
+
+ BOOST_CHECK(!testNotifyBackends.m_TestNotifyBackendsCalled);
+ BOOST_CHECK(!testNotifyBackends.m_timelineReporting.load());
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceNotActive)
+{
+ using namespace armnn;
+ using namespace armnn::profiling;
+
+ // Create runtime in which the test will run
+ armnn::IRuntime::CreationOptions options;
+ options.m_ProfilingOptions.m_EnableProfiling = true;
+
+ armnn::Runtime runtime(options);
+ profiling::ProfilingServiceRuntimeHelper profilingServiceHelper(GetProfilingService(&runtime));
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected);
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck);
+ profilingServiceHelper.ForceTransitionToState(ProfilingState::Active);
+
+ profiling::BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager();
+ auto readableBuffer = bufferManager.GetReadableBuffer();
+
+ // Profiling is enabled, the post-optimisation structure should be created
+ BOOST_CHECK(readableBuffer == nullptr);
+}
+
BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged)
{
using boost::numeric_cast;
@@ -3395,8 +3522,7 @@ BOOST_AUTO_TEST_CASE(CheckRegisterCounters)
MockBufferManager mockBuffer(1024);
CaptureData captureData;
- MockProfilingService mockProfilingService(
- mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData);
+ MockProfilingService mockProfilingService(mockBuffer, options.m_ProfilingOptions.m_EnableProfiling, captureData);
armnn::BackendId cpuRefId(armnn::Compute::CpuRef);
mockProfilingService.RegisterMapping(6, 0, cpuRefId);
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
index a5971e0a4b..d1052cea97 100644
--- a/src/profiling/test/ProfilingTests.hpp
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -242,7 +242,7 @@ public:
uint32_t length = 0,
uint32_t timeout = 1000)
{
- long packetCount = mockProfilingConnection->CheckForPacket({packetType, length});
+ long packetCount = mockProfilingConnection->CheckForPacket({ packetType, length });
// The first packet we receive may not be the one we are looking for, so keep looping until till we find it,
// or until WaitForPacketsSent times out
while(packetCount == 0 && timeout != 0)
diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp
index d7dc7e2d9e..51f049ddc6 100644
--- a/src/profiling/test/SendCounterPacketTests.cpp
+++ b/src/profiling/test/SendCounterPacketTests.cpp
@@ -1172,7 +1172,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
BOOST_CHECK(counterDirectory.GetCategoryCount() == 2);
BOOST_CHECK(category2);
- uint16_t numberOfCores = 3;
+ uint16_t numberOfCores = 4;
// Register a counter associated to "category1"
const Counter* counter1 = nullptr;
@@ -1186,7 +1186,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
"counter1description",
std::string("counter1units"),
numberOfCores));
- BOOST_CHECK(counterDirectory.GetCounterCount() == 3);
+ BOOST_CHECK(counterDirectory.GetCounterCount() == 4);
BOOST_CHECK(counter1);
// Register a counter associated to "category1"
@@ -1203,7 +1203,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
armnn::EmptyOptional(),
device2->m_Uid,
0));
- BOOST_CHECK(counterDirectory.GetCounterCount() == 4);
+ BOOST_CHECK(counterDirectory.GetCounterCount() == 5);
BOOST_CHECK(counter2);
// Register a counter associated to "category2"
@@ -1217,7 +1217,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
"counter3",
"counter3description",
armnn::EmptyOptional(),
- 5,
+ numberOfCores,
device2->m_Uid,
counterSet1->m_Uid));
BOOST_CHECK(counterDirectory.GetCounterCount() == 9);
@@ -1236,7 +1236,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
uint32_t packetHeaderWord1 = ReadUint32(readBuffer, 4);
BOOST_TEST(((packetHeaderWord0 >> 26) & 0x3F) == 0); // packet_family
BOOST_TEST(((packetHeaderWord0 >> 16) & 0x3FF) == 2); // packet_id
- BOOST_TEST(packetHeaderWord1 == 928); // data_length
+ BOOST_TEST(packetHeaderWord1 == 432); // data_length
// Check the body header
uint32_t bodyHeaderWord0 = ReadUint32(readBuffer, 8);
@@ -1269,7 +1269,7 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
uint32_t categoryRecordOffset0 = ReadUint32(readBuffer, 44);
uint32_t categoryRecordOffset1 = ReadUint32(readBuffer, 48);
BOOST_TEST(categoryRecordOffset0 == 64); // Category record offset for "category1"
- BOOST_TEST(categoryRecordOffset1 == 472); // Category record offset for "category2"
+ BOOST_TEST(categoryRecordOffset1 == 168); // Category record offset for "category2"
// Get the device record pool offset
uint32_t uint32_t_size = sizeof(uint32_t);
@@ -1584,7 +1584,8 @@ BOOST_AUTO_TEST_CASE(SendCounterDirectoryPacketTest2)
const Category* category = counterDirectory.GetCategory(categoryRecord.name);
BOOST_CHECK(category);
BOOST_CHECK(category->m_Name == categoryRecord.name);
- BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count);
+ BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast<size_t>(numberOfCores) -1);
+ BOOST_CHECK(category->m_Counters.size() == categoryRecord.event_count + static_cast<size_t>(numberOfCores) -1);
// Check that the event records are correct
for (const EventRecord& eventRecord : categoryRecord.event_records)
diff --git a/src/profiling/test/SendTimelinePacketTests.cpp b/src/profiling/test/SendTimelinePacketTests.cpp
index 98b161f65e..4a13ebf824 100644
--- a/src/profiling/test/SendTimelinePacketTests.cpp
+++ b/src/profiling/test/SendTimelinePacketTests.cpp
@@ -15,6 +15,7 @@
#include <LabelsAndEventClasses.hpp>
#include <functional>
+#include <Runtime.hpp>
using namespace armnn::profiling;
@@ -395,10 +396,12 @@ BOOST_AUTO_TEST_CASE(SendTimelinePacketTests3)
BOOST_AUTO_TEST_CASE(GetGuidsFromProfilingService)
{
- armnn::IRuntime::CreationOptions::ExternalProfilingOptions options;
- options.m_EnableProfiling = true;
- armnn::profiling::ProfilingService profilingService;
- profilingService.ResetExternalProfilingOptions(options, true);
+ armnn::IRuntime::CreationOptions options;
+ options.m_ProfilingOptions.m_EnableProfiling = true;
+ armnn::Runtime runtime(options);
+ armnn::profiling::ProfilingService profilingService(runtime);
+
+ profilingService.ResetExternalProfilingOptions(options.m_ProfilingOptions, true);
ProfilingStaticGuid staticGuid = profilingService.GetStaticId("dummy");
std::hash<std::string> hasher;
uint64_t hash = static_cast<uint64_t>(hasher("dummy"));
diff --git a/tests/profiling/gatordmock/GatordMockMain.cpp b/tests/profiling/gatordmock/GatordMockMain.cpp
index e19461f6cb..029c58f5e8 100644
--- a/tests/profiling/gatordmock/GatordMockMain.cpp
+++ b/tests/profiling/gatordmock/GatordMockMain.cpp
@@ -3,16 +3,10 @@
// SPDX-License-Identifier: MIT
//
-#include "PacketVersionResolver.hpp"
#include "CommandFileParser.hpp"
#include "CommandLineProcessor.hpp"
-#include "DirectoryCaptureCommandHandler.hpp"
#include "GatordMockService.hpp"
-#include "PeriodicCounterCaptureCommandHandler.hpp"
-#include "PeriodicCounterSelectionResponseHandler.hpp"
#include <TimelineDecoder.hpp>
-#include <TimelineDirectoryCaptureCommandHandler.hpp>
-#include <TimelineCaptureCommandHandler.hpp>
#include <iostream>
#include <string>
@@ -32,38 +26,7 @@ void exit_capture(int signum)
bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string commandFile, bool isEchoEnabled)
{
- profiling::PacketVersionResolver packetVersionResolver;
- // Create the Command Handler Registry
- profiling::CommandHandlerRegistry registry;
-
- timelinedecoder::TimelineDecoder timelineDecoder;
- timelineDecoder.SetDefaultCallbacks();
-
- // This functor will receive back the selection response packet.
- PeriodicCounterSelectionResponseHandler periodicCounterSelectionResponseHandler(
- 0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue());
- // This functor will receive the counter data.
- PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler(
- 3, 0, packetVersionResolver.ResolvePacketVersion(3, 0).GetEncodedValue());
-
- profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
- 0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), false);
-
- timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler(
- 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
-
- timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
- 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
- timelineCaptureCommandHandler, false);
-
- // Register different derived functors
- registry.RegisterFunctor(&periodicCounterSelectionResponseHandler);
- registry.RegisterFunctor(&counterCaptureCommandHandler);
- registry.RegisterFunctor(&directoryCaptureCommandHandler);
- registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler);
- registry.RegisterFunctor(&timelineCaptureCommandHandler);
-
- GatordMockService mockService(clientConnection, registry, isEchoEnabled);
+ GatordMockService mockService(clientConnection, isEchoEnabled);
// Send receive the strweam metadata and send connection ack.
if (!mockService.WaitForStreamMetaData())
@@ -82,11 +45,6 @@ bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string
// Once we've finished processing the file wait for the receiving thread to close.
mockService.WaitForReceivingThread();
- if(isEchoEnabled)
- {
- timelineDecoder.print();
- }
-
return EXIT_SUCCESS;
}
diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp
index a3f732cb55..3e19c25b6c 100644
--- a/tests/profiling/gatordmock/GatordMockService.cpp
+++ b/tests/profiling/gatordmock/GatordMockService.cpp
@@ -131,10 +131,30 @@ void GatordMockService::SendRequestCounterDir()
{
std::cout << "Sending connection acknowledgement." << std::endl;
}
- // The connection ack packet is an empty data packet with packetId == 1.
+ // The request counter directory packet is an empty data packet with packetId == 3.
SendPacket(0, 3, nullptr, 0);
}
+void GatordMockService::SendActivateTimelinePacket()
+{
+ if (m_EchoPackets)
+ {
+ std::cout << "Sending activate timeline packet." << std::endl;
+ }
+ // The activate timeline packet is an empty data packet with packetId == 6.
+ SendPacket(0, 6, nullptr, 0);
+}
+
+void GatordMockService::SendDeactivateTimelinePacket()
+{
+ if (m_EchoPackets)
+ {
+ std::cout << "Sending deactivate timeline packet." << std::endl;
+ }
+ // The deactivate timeline packet is an empty data packet with packetId == 7.
+ SendPacket(0, 7, nullptr, 0);
+}
+
bool GatordMockService::LaunchReceivingThread()
{
if (m_EchoPackets)
@@ -165,6 +185,11 @@ void GatordMockService::WaitForReceivingThread()
// Wait for the receiving thread to complete operations
m_ListeningThread.join();
}
+
+ if(m_EchoPackets)
+ {
+ m_TimelineDecoder.print();
+ }
}
void GatordMockService::SendPeriodicCounterSelectionList(uint32_t period, std::vector<uint16_t> counters)
diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp
index c00685fff2..2ff93c9de6 100644
--- a/tests/profiling/gatordmock/GatordMockService.hpp
+++ b/tests/profiling/gatordmock/GatordMockService.hpp
@@ -13,6 +13,15 @@
#include <string>
#include <thread>
+#include <TimelineDecoder.hpp>
+#include <DirectoryCaptureCommandHandler.hpp>
+#include <TimelineCaptureCommandHandler.hpp>
+#include <TimelineDirectoryCaptureCommandHandler.hpp>
+#include "PeriodicCounterCaptureCommandHandler.hpp"
+#include "StreamMetadataCommandHandler.hpp"
+
+#include "PacketVersionResolver.hpp"
+
namespace armnn
{
@@ -39,15 +48,33 @@ class GatordMockService
public:
/// @param registry reference to a command handler registry.
/// @param echoPackets if true the raw packets will be printed to stdout.
- GatordMockService(armnnUtils::Sockets::Socket clientConnection,
- armnn::profiling::CommandHandlerRegistry& registry,
- bool echoPackets)
+ GatordMockService(armnnUtils::Sockets::Socket clientConnection, bool echoPackets)
: m_ClientConnection(clientConnection)
- , m_HandlerRegistry(registry)
+ , m_PacketsReceivedCount(0)
, m_EchoPackets(echoPackets)
, m_CloseReceivingThread(false)
+ , m_PacketVersionResolver()
+ , m_HandlerRegistry()
+ , m_TimelineDecoder()
+ , m_StreamMetadataCommandHandler(
+ 0, 0, m_PacketVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true)
+ , m_CounterCaptureCommandHandler(
+ 0, 4, m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true)
+ , m_DirectoryCaptureCommandHandler(
+ 0, 2, m_PacketVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true)
+ , m_TimelineCaptureCommandHandler(
+ 1, 1, m_PacketVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), m_TimelineDecoder)
+ , m_TimelineDirectoryCaptureCommandHandler(
+ 1, 0, m_PacketVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
+ m_TimelineCaptureCommandHandler, true)
{
- m_PacketsReceivedCount.store(0, std::memory_order_relaxed);
+ m_TimelineDecoder.SetDefaultCallbacks();
+
+ m_HandlerRegistry.RegisterFunctor(&m_StreamMetadataCommandHandler);
+ m_HandlerRegistry.RegisterFunctor(&m_CounterCaptureCommandHandler);
+ m_HandlerRegistry.RegisterFunctor(&m_DirectoryCaptureCommandHandler);
+ m_HandlerRegistry.RegisterFunctor(&m_TimelineDirectoryCaptureCommandHandler);
+ m_HandlerRegistry.RegisterFunctor(&m_TimelineCaptureCommandHandler);
}
~GatordMockService()
@@ -74,6 +101,12 @@ public:
/// Send a request counter directory packet back to the client.
void SendRequestCounterDir();
+ /// Send a activate timeline packet back to the client.
+ void SendActivateTimelinePacket();
+
+ /// Send a deactivate timeline packet back to the client.
+ void SendDeactivateTimelinePacket();
+
/// Start the thread that will receive all packets and print them nicely to stdout.
bool LaunchReceivingThread();
@@ -115,6 +148,22 @@ public:
return m_StreamMetaDataPid;
}
+ profiling::DirectoryCaptureCommandHandler& GetDirectoryCaptureCommandHandler()
+ {
+ return m_DirectoryCaptureCommandHandler;
+ }
+
+ timelinedecoder::TimelineDecoder& GetTimelineDecoder()
+ {
+ return m_TimelineDecoder;
+ }
+
+ timelinedecoder::TimelineDirectoryCaptureCommandHandler& GetTimelineDirectoryCaptureCommandHandler()
+ {
+ return m_TimelineDirectoryCaptureCommandHandler;
+ }
+
+
private:
void ReceiveLoop(GatordMockService& mockService);
@@ -141,18 +190,30 @@ private:
static const uint32_t PIPE_MAGIC = 0x45495434;
- std::atomic<uint32_t> m_PacketsReceivedCount;
TargetEndianness m_Endianness;
uint32_t m_StreamMetaDataVersion;
uint32_t m_StreamMetaDataMaxDataLen;
uint32_t m_StreamMetaDataPid;
armnnUtils::Sockets::Socket m_ClientConnection;
- armnn::profiling::CommandHandlerRegistry& m_HandlerRegistry;
+ std::atomic<uint32_t> m_PacketsReceivedCount;
bool m_EchoPackets;
std::thread m_ListeningThread;
std::atomic<bool> m_CloseReceivingThread;
+
+ profiling::PacketVersionResolver m_PacketVersionResolver;
+ profiling::CommandHandlerRegistry m_HandlerRegistry;
+
+ timelinedecoder::TimelineDecoder m_TimelineDecoder;
+
+ gatordmock::StreamMetadataCommandHandler m_StreamMetadataCommandHandler;
+ gatordmock::PeriodicCounterCaptureCommandHandler m_CounterCaptureCommandHandler;
+
+ profiling::DirectoryCaptureCommandHandler m_DirectoryCaptureCommandHandler;
+
+ timelinedecoder::TimelineCaptureCommandHandler m_TimelineCaptureCommandHandler;
+ timelinedecoder::TimelineDirectoryCaptureCommandHandler m_TimelineDirectoryCaptureCommandHandler;
};
} // namespace gatordmock
diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
index 7d938bd404..7417946844 100644
--- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp
+++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
@@ -9,12 +9,14 @@
#include <LabelsAndEventClasses.hpp>
#include <PeriodicCounterCaptureCommandHandler.hpp>
#include <ProfilingService.hpp>
-#include <StreamMetadataCommandHandler.hpp>
#include <TimelinePacketWriterFactory.hpp>
#include <TimelineDirectoryCaptureCommandHandler.hpp>
#include <TimelineDecoder.hpp>
+#include <Runtime.hpp>
+#include "../../src/backends/backendsCommon/test/MockBackend.hpp"
+
#include <boost/cast.hpp>
#include <boost/test/test_tools.hpp>
#include <boost/test/unit_test_suite.hpp>
@@ -104,6 +106,19 @@ BOOST_AUTO_TEST_CASE(CounterCaptureHandlingTest)
}
}
+void WaitFor(std::function<bool()> predicate, std::string errorMsg, uint32_t timeout = 2000, uint32_t sleepTime = 50)
+{
+ uint32_t timeSlept = 0;
+ while (!predicate())
+ {
+ if (timeSlept >= timeout)
+ {
+ BOOST_FAIL("Timeout: " + errorMsg);
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
+ timeSlept += sleepTime;
+ }
+}
void CheckTimelineDirectory(timelinedecoder::TimelineDirectoryCaptureCommandHandler& commandHandler)
{
@@ -211,43 +226,6 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
// The purpose of this test is to setup both sides of the profiling service and get to the point of receiving
// performance data.
- //These variables are used to wait for the profiling service
- uint32_t timeout = 2000;
- uint32_t sleepTime = 50;
- uint32_t timeSlept = 0;
-
- profiling::PacketVersionResolver packetVersionResolver;
-
- // Create the Command Handler Registry
- profiling::CommandHandlerRegistry registry;
-
- timelinedecoder::TimelineDecoder timelineDecoder;
- timelineDecoder.SetDefaultCallbacks();
-
- // Update with derived functors
- gatordmock::StreamMetadataCommandHandler streamMetadataCommandHandler(
- 0, 0, packetVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true);
-
- gatordmock::PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler(
- 0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true);
-
- profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
- 0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true);
-
- timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler(
- 1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
-
- timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
- 1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
- timelineCaptureCommandHandler, true);
-
- // Register different derived functors
- registry.RegisterFunctor(&streamMetadataCommandHandler);
- registry.RegisterFunctor(&counterCaptureCommandHandler);
- registry.RegisterFunctor(&directoryCaptureCommandHandler);
- registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler);
- registry.RegisterFunctor(&timelineCaptureCommandHandler);
-
// Setup the mock service to bind to the UDS.
std::string udsNamespace = "gatord_namespace";
@@ -279,18 +257,15 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
BOOST_FAIL("Failed to connect client");
}
- gatordmock::GatordMockService mockService(clientSocket, registry, false);
+ gatordmock::GatordMockService mockService(clientSocket, false);
+
+ timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder();
+ profiling::DirectoryCaptureCommandHandler& directoryCaptureCommandHandler =
+ mockService.GetDirectoryCaptureCommandHandler();
// Give the profiling service sending thread time start executing and send the stream metadata.
- while (profilingService.GetCurrentState() != profiling::ProfilingState::WaitingForAck)
- {
- if (timeSlept >= timeout)
- {
- BOOST_FAIL("Timeout: Profiling service did not switch to WaitingForAck state");
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
- timeSlept += sleepTime;
- }
+ WaitFor([&](){return profilingService.GetCurrentState() == profiling::ProfilingState::WaitingForAck;},
+ "Profiling service did not switch to WaitingForAck state");
profilingService.Update();
// Read the stream metadata on the mock side.
@@ -300,55 +275,21 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
}
// Send Ack from GatorD
mockService.SendConnectionAck();
+ // And start to listen for packets
+ mockService.LaunchReceivingThread();
- timeSlept = 0;
- while (profilingService.GetCurrentState() != profiling::ProfilingState::Active)
- {
- if (timeSlept >= timeout)
- {
- BOOST_FAIL("Timeout: Profiling service did not switch to Active state");
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
- timeSlept += sleepTime;
- }
+ WaitFor([&](){return profilingService.GetCurrentState() == profiling::ProfilingState::Active;},
+ "Profiling service did not switch to Active state");
- mockService.LaunchReceivingThread();
// As part of the default startup of the profiling service a counter directory packet will be sent.
- timeSlept = 0;
- while (!directoryCaptureCommandHandler.ParsedCounterDirectory())
- {
- if (timeSlept >= timeout)
- {
- BOOST_FAIL("Timeout: MockGatord did not receive counter directory packet");
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
- timeSlept += sleepTime;
- }
+ WaitFor([&](){return directoryCaptureCommandHandler.ParsedCounterDirectory();},
+ "MockGatord did not receive counter directory packet");
- // As part of the default startup of the profiling service a counter directory packet will be sent.
- timeSlept = 0;
- while (!directoryCaptureCommandHandler.ParsedCounterDirectory())
- {
- if (timeSlept >= timeout)
- {
- BOOST_FAIL("Timeout: MockGatord did not receive counter directory packet");
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
- timeSlept += sleepTime;
- }
+ // Following that we will receive a collection of well known timeline labels and event classes
+ WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;},
+ "MockGatord did not receive well known timeline labels and event classes");
- timeSlept = 0;
- while (timelineDecoder.GetModel().m_EventClasses.size() < 2)
- {
- if (timeSlept >= timeout)
- {
- BOOST_FAIL("Timeout: MockGatord did not receive well known timeline labels");
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
- timeSlept += sleepTime;
- }
-
- CheckTimelineDirectory(timelineDirectoryCaptureCommandHandler);
+ CheckTimelineDirectory(mockService.GetTimelineDirectoryCaptureCommandHandler());
// Verify the commonly used timeline packets sent when the profiling service enters the active state
CheckTimelinePackets(timelineDecoder);
@@ -439,4 +380,107 @@ BOOST_AUTO_TEST_CASE(GatorDMockEndToEnd)
// PeriodicCounterCapture data received. These are yet to be integrated.
}
+BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation)
+{
+ armnn::MockBackendInitialiser initialiser;
+ // Setup the mock service to bind to the UDS.
+ std::string udsNamespace = "gatord_namespace";
+
+ armnnUtils::Sockets::Initialize();
+ armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
+
+ if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace))
+ {
+ BOOST_FAIL("Failed to open Listening Socket");
+ }
+
+ armnn::IRuntime::CreationOptions options;
+ options.m_ProfilingOptions.m_EnableProfiling = true;
+ armnn::Runtime runtime(options);
+
+ armnnUtils::Sockets::Socket clientConnection;
+ clientConnection = armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
+ gatordmock::GatordMockService mockService(clientConnection, false);
+
+ // Read the stream metadata on the mock side.
+ if (!mockService.WaitForStreamMetaData())
+ {
+ BOOST_FAIL("Failed to receive StreamMetaData");
+ }
+
+ armnn::MockBackendProfilingService mockProfilingService = armnn::MockBackendProfilingService::Instance();
+ armnn::MockBackendProfilingContext *mockBackEndProfilingContext = mockProfilingService.GetContext();
+
+ // Send Ack from GatorD
+ mockService.SendConnectionAck();
+ // And start to listen for packets
+ mockService.LaunchReceivingThread();
+
+ // Build and optimize a simple network while we wait
+ INetworkPtr net(INetwork::Create());
+
+ IConnectableLayer* input = net->AddInputLayer(0, "input");
+
+ NormalizationDescriptor descriptor;
+ IConnectableLayer* normalize = net->AddNormalizationLayer(descriptor, "normalization");
+
+ IConnectableLayer* output = net->AddOutputLayer(0, "output");
+
+ input->GetOutputSlot(0).Connect(normalize->GetInputSlot(0));
+ normalize->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 4 }, DataType::Float32));
+ normalize->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 1, 1, 4, 4 }, DataType::Float32));
+
+ std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
+ IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime.GetDeviceSpec());
+
+ WaitFor([&](){return mockService.GetDirectoryCaptureCommandHandler().ParsedCounterDirectory();},
+ "MockGatord did not receive counter directory packet");
+
+ timelinedecoder::TimelineDecoder& timelineDecoder = mockService.GetTimelineDecoder();
+
+ WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;},
+ "MockGatord did not receive well known timeline labels");
+
+ // Packets we expect from SendWellKnownLabelsAndEventClassesTest
+ BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 0);
+ BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 2);
+ BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 10);
+ BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 0);
+ BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0);
+
+ mockService.SendDeactivateTimelinePacket();
+
+ WaitFor([&](){return !mockBackEndProfilingContext->TimelineReportingEnabled();},
+ "Timeline packets were not deactivated");
+
+ // Load the network into runtime now that timeline reporting is disabled
+ armnn::NetworkId netId;
+ runtime.LoadNetwork(netId, std::move(optNet));
+
+ // Now activate timeline packets
+ mockService.SendActivateTimelinePacket();
+
+ WaitFor([&](){return mockBackEndProfilingContext->TimelineReportingEnabled();},
+ "Timeline packets were not activated");
+
+ // Once timeline packets have been reactivated the ActivateTimelineReportingCommandHandler will resend the
+ // SendWellKnownLabelsAndEventClasses and then send the structure of any loaded networks
+ WaitFor([&](){return timelineDecoder.GetModel().m_Labels.size() >= 24;},
+ "MockGatord did not receive well known timeline labels");
+
+ // Packets we expect from SendWellKnownLabelsAndEventClassesTest * 2 and the loaded model
+ BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 5);
+ BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size() == 4);
+ BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size() == 24);
+ BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size() == 28);
+ BOOST_CHECK(timelineDecoder.GetModel().m_Events.size() == 0);
+
+ mockService.WaitForReceivingThread();
+ armnnUtils::Sockets::Close(listeningSocket);
+
+ GetProfilingService(&runtime).Disconnect();
+}
+
BOOST_AUTO_TEST_SUITE_END()