From 68f78d8ef0134aaaf10ee4db94e808f68f1ba2a8 Mon Sep 17 00:00:00 2001 From: Francis Murtagh Date: Wed, 4 Sep 2019 16:42:29 +0100 Subject: IVGCVSW-3432 Create CaptureData Holder * Create CaptureData and Holder classes * Add unit test Signed-off-by: Ellen Norris-Thompson Signed-off-by: Francis Murtagh Change-Id: I9f2766a8a6081ae4f9988904af2ca24cd434ebca --- src/profiling/Holder.cpp | 57 +++++++++++++++++++++ src/profiling/Holder.hpp | 53 +++++++++++++++++++ src/profiling/test/ProfilingTests.cpp | 95 +++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+) create mode 100644 src/profiling/Holder.cpp create mode 100644 src/profiling/Holder.hpp (limited to 'src/profiling') diff --git a/src/profiling/Holder.cpp b/src/profiling/Holder.cpp new file mode 100644 index 0000000000..9def49d22e --- /dev/null +++ b/src/profiling/Holder.cpp @@ -0,0 +1,57 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "Holder.hpp" + +namespace armnn +{ + +namespace profiling +{ + +CaptureData& CaptureData::operator= (const CaptureData& captureData) +{ + m_CapturePeriod = captureData.m_CapturePeriod; + m_CounterIds = captureData.m_CounterIds; + + return *this; +} + +void CaptureData::SetCapturePeriod(uint32_t capturePeriod) +{ + m_CapturePeriod = capturePeriod; +} + +void CaptureData::SetCounterIds(std::vector& counterIds) +{ + m_CounterIds = counterIds; +} + +std::uint32_t CaptureData::GetCapturePeriod() const +{ + return m_CapturePeriod; +} + +std::vector CaptureData::GetCounterIds() const +{ + return m_CounterIds; +} + +CaptureData Holder::GetCaptureData() const +{ + std::lock_guard lockGuard(m_CaptureThreadMutex); + return m_CaptureData; +} + +void Holder::SetCaptureData(uint32_t capturePeriod, std::vector& counterIds) +{ + std::lock_guard lockGuard(m_CaptureThreadMutex); + m_CaptureData.SetCapturePeriod(capturePeriod); + m_CaptureData.SetCounterIds(counterIds); +} + +} // namespace profiling + +} // namespace armnn \ No newline at end of file diff --git a/src/profiling/Holder.hpp b/src/profiling/Holder.hpp new file mode 100644 index 0000000000..c22c72a929 --- /dev/null +++ b/src/profiling/Holder.hpp @@ -0,0 +1,53 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include +#include + +namespace armnn +{ + +namespace profiling +{ + +class CaptureData +{ +public: + CaptureData() + : m_CapturePeriod(0), m_CounterIds() {}; + CaptureData(uint32_t capturePeriod, std::vector& counterIds) + : m_CapturePeriod(capturePeriod), m_CounterIds(counterIds) {}; + CaptureData(const CaptureData& captureData) + : m_CapturePeriod(captureData.m_CapturePeriod), m_CounterIds(captureData.m_CounterIds) {}; + + CaptureData& operator= (const CaptureData& captureData); + + void SetCapturePeriod(uint32_t capturePeriod); + void SetCounterIds(std::vector& counterIds); + uint32_t GetCapturePeriod() const; + std::vector GetCounterIds() const; + +private: + uint32_t m_CapturePeriod; + std::vector m_CounterIds; +}; + +class Holder +{ +public: + Holder() + : m_CaptureData() {}; + CaptureData GetCaptureData() const; + void SetCaptureData(uint32_t capturePeriod, std::vector& counterIds); + +private: + mutable std::mutex m_CaptureThreadMutex; + CaptureData m_CaptureData; +}; + +} // namespace profiling + +} // namespace armnn diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index ce278abee7..c7b0bda0ff 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -7,6 +7,7 @@ #include "../CommandHandlerFunctor.hpp" #include "../CommandHandlerRegistry.hpp" #include "../EncodeVersion.hpp" +#include "../Holder.hpp" #include "../Packet.hpp" #include "../PacketVersionResolver.hpp" #include "../ProfilingStateMachine.hpp" @@ -357,4 +358,98 @@ BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine) BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected)); } +void CaptureDataWriteThreadImpl(Holder &holder, uint32_t capturePeriod, std::vector& counterIds) +{ + holder.SetCaptureData(capturePeriod, counterIds); +} + +void CaptureDataReadThreadImpl(Holder &holder, CaptureData& captureData) +{ + captureData = holder.GetCaptureData(); +} + +BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder) +{ + std::vector counterIds1 = {}; + uint32_t capturePeriod1(1); + std::vector counterIds2 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + uint32_t capturePeriod2(2); + std::vector counterIds3 = {4, 5, 5, 6}; + uint32_t capturePeriod3(3); + + // Check CaptureData functions + CaptureData capture; + BOOST_CHECK(capture.GetCapturePeriod() == 0); + BOOST_CHECK((capture.GetCounterIds()).empty()); + capture.SetCapturePeriod(capturePeriod2); + capture.SetCounterIds(counterIds2); + BOOST_CHECK(capture.GetCapturePeriod() == capturePeriod2); + BOOST_CHECK(capture.GetCounterIds() == counterIds2); + + Holder holder; + BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0); + BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty()); + + // Check Holder functions + std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod3, std::ref(counterIds3)); + thread1.join(); + + BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == capturePeriod3); + BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == counterIds3); + + CaptureData captureData; + std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData)); + thread2.join(); + BOOST_CHECK(captureData.GetCounterIds() == counterIds3); + + std::thread thread3(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod2, std::ref(counterIds1)); + std::thread thread4(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData)); + std::thread thread5(CaptureDataWriteThreadImpl, std::ref(holder), capturePeriod1, std::ref(counterIds2)); + thread3.join(); + thread4.join(); + thread5.join(); + + // Check CaptureData was written/read correctly from multiple threads + std::vector captureIds = captureData.GetCounterIds(); + uint32_t capturePeriod = captureData.GetCapturePeriod(); + if (captureIds == counterIds1) + { + BOOST_CHECK(capturePeriod == capturePeriod2); + } + else if (captureIds == counterIds2) + { + BOOST_CHECK(capturePeriod == capturePeriod1); + } + else + { + BOOST_ERROR("Error in CaptureData read/write."); + } + + std::vector readIds = holder.GetCaptureData().GetCounterIds(); + BOOST_CHECK(readIds == counterIds1 || readIds == counterIds2); + + // Check assignment operator + CaptureData assignableCaptureData; + assignableCaptureData.SetCapturePeriod(capturePeriod3); + assignableCaptureData.SetCounterIds(counterIds3); + + CaptureData secondCaptureData; + secondCaptureData.SetCapturePeriod(capturePeriod2); + secondCaptureData.SetCounterIds(counterIds2); + + BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 2); + BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds2); + + secondCaptureData = assignableCaptureData; + BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 3); + BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds3); + + // Check copy constructor + CaptureData copyConstructedCaptureData(assignableCaptureData); + + BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 3); + BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds3); + +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1