From 8355ec982eb3ff51a6a8042fe760138638ca550b Mon Sep 17 00:00:00 2001 From: Jim Flynn Date: Tue, 17 Sep 2019 12:29:50 +0100 Subject: IVGCVSW-3432 Fix a multithread store conflict * Unit test was using the same CaptureData object across 50 threads Change-Id: I0249b5a8e0bb05e3d3efdd855f5b34b1d5ef3dc9 Signed-off-by: Jim Flynn --- src/profiling/Holder.cpp | 6 +-- src/profiling/Holder.hpp | 4 +- src/profiling/test/ProfilingTests.cpp | 97 +++++++++++++++++------------------ 3 files changed, 53 insertions(+), 54 deletions(-) (limited to 'src/profiling') diff --git a/src/profiling/Holder.cpp b/src/profiling/Holder.cpp index 9def49d22e..5916017eb6 100644 --- a/src/profiling/Holder.cpp +++ b/src/profiling/Holder.cpp @@ -24,7 +24,7 @@ void CaptureData::SetCapturePeriod(uint32_t capturePeriod) m_CapturePeriod = capturePeriod; } -void CaptureData::SetCounterIds(std::vector& counterIds) +void CaptureData::SetCounterIds(const std::vector& counterIds) { m_CounterIds = counterIds; } @@ -45,7 +45,7 @@ CaptureData Holder::GetCaptureData() const return m_CaptureData; } -void Holder::SetCaptureData(uint32_t capturePeriod, std::vector& counterIds) +void Holder::SetCaptureData(uint32_t capturePeriod, const std::vector& counterIds) { std::lock_guard lockGuard(m_CaptureThreadMutex); m_CaptureData.SetCapturePeriod(capturePeriod); @@ -54,4 +54,4 @@ void Holder::SetCaptureData(uint32_t capturePeriod, std::vector& count } // namespace profiling -} // namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/profiling/Holder.hpp b/src/profiling/Holder.hpp index c22c72a929..d8d1f5bfb4 100644 --- a/src/profiling/Holder.hpp +++ b/src/profiling/Holder.hpp @@ -26,7 +26,7 @@ public: CaptureData& operator= (const CaptureData& captureData); void SetCapturePeriod(uint32_t capturePeriod); - void SetCounterIds(std::vector& counterIds); + void SetCounterIds(const std::vector& counterIds); uint32_t GetCapturePeriod() const; std::vector GetCounterIds() const; @@ -41,7 +41,7 @@ public: Holder() : m_CaptureData() {}; CaptureData GetCaptureData() const; - void SetCaptureData(uint32_t capturePeriod, std::vector& counterIds); + void SetCaptureData(uint32_t capturePeriod, const std::vector& counterIds); private: mutable std::mutex m_CaptureThreadMutex; diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 3e59d09848..51dbb07a58 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -395,7 +395,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine) BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected)); } -void CaptureDataWriteThreadImpl(Holder &holder, uint32_t capturePeriod, std::vector& counterIds) +void CaptureDataWriteThreadImpl(Holder& holder, uint32_t capturePeriod, const std::vector& counterIds) { holder.SetCaptureData(capturePeriod, counterIds); } @@ -409,22 +409,15 @@ BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder) { std::map> periodIdMap; std::vector counterIds; - uint16_t numThreads = 50; - for (uint16_t i = 0; i < numThreads; ++i) + uint32_t numThreads = 10; + for (uint32_t i = 0; i < numThreads; ++i) { counterIds.emplace_back(i); periodIdMap.insert(std::make_pair(i, counterIds)); } - // Check CaptureData functions - CaptureData capture; - BOOST_CHECK(capture.GetCapturePeriod() == 0); - BOOST_CHECK((capture.GetCounterIds()).empty()); - capture.SetCapturePeriod(0); - capture.SetCounterIds(periodIdMap[0]); - BOOST_CHECK(capture.GetCapturePeriod() == 0); - BOOST_CHECK(capture.GetCounterIds() == periodIdMap[0]); - + // Verify the read and write threads set the holder correctly + // and retrieve the expected values Holder holder; BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0); BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty()); @@ -432,76 +425,82 @@ BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder) // Check Holder functions std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), 2, std::ref(periodIdMap[2])); thread1.join(); - BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 2); BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == periodIdMap[2]); + // NOTE: now that we have some initial values in the holder we don't have to worry + // in the multi-threaded section below about a read thread accessing the holder + // before any write thread has gotten to it so we read period = 0, counterIds empty + // instead of period = 0, counterIds = {0} as will the case when write thread 0 + // has executed. CaptureData captureData; std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData)); thread2.join(); + BOOST_CHECK(captureData.GetCapturePeriod() == 2); BOOST_CHECK(captureData.GetCounterIds() == periodIdMap[2]); + std::map captureDataIdMap; + for (uint32_t i = 0; i < numThreads; ++i) + { + CaptureData perThreadCaptureData; + captureDataIdMap.insert(std::make_pair(i, perThreadCaptureData)); + } + std::vector threadsVect; - for (int i = 0; i < numThreads; i+=2) + std::vector readThreadsVect; + for (uint32_t i = 0; i < numThreads; ++i) { threadsVect.emplace_back(std::thread(CaptureDataWriteThreadImpl, std::ref(holder), i, - std::ref(periodIdMap[static_cast(i)]))); - - threadsVect.emplace_back(std::thread(CaptureDataReadThreadImpl, - std::ref(holder), - std::ref(captureData))); + std::ref(periodIdMap[i]))); + + // Verify that the CaptureData goes into the thread in a virgin state + BOOST_CHECK(captureDataIdMap.at(i).GetCapturePeriod() == 0); + BOOST_CHECK(captureDataIdMap.at(i).GetCounterIds().empty()); + readThreadsVect.emplace_back(std::thread(CaptureDataReadThreadImpl, + std::ref(holder), + std::ref(captureDataIdMap.at(i)))); } - for (uint16_t i = 0; i < numThreads; ++i) + for (uint32_t i = 0; i < numThreads; ++i) { threadsVect[i].join(); + readThreadsVect[i].join(); } - std::vector readThreadsVect; - for (uint16_t i = 0; i < numThreads; ++i) - { - readThreadsVect.emplace_back( - std::thread(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData))); - } - - for (uint16_t i = 0; i < numThreads; ++i) + // Look at the CaptureData that each read thread has filled + // the capture period it read should match the counter ids entry + for (uint32_t i = 0; i < numThreads; ++i) { - readThreadsVect[i].join(); + CaptureData perThreadCaptureData = captureDataIdMap.at(i); + BOOST_CHECK(perThreadCaptureData.GetCounterIds() == periodIdMap.at(perThreadCaptureData.GetCapturePeriod())); } - - // Check CaptureData was written/read correctly from multiple threads - std::vector captureIds = captureData.GetCounterIds(); - uint32_t capturePeriod = captureData.GetCapturePeriod(); - - BOOST_CHECK(captureIds == periodIdMap[capturePeriod]); - - std::vector readIds = holder.GetCaptureData().GetCounterIds(); - BOOST_CHECK(captureIds == readIds); } BOOST_AUTO_TEST_CASE(CaptureDataMethods) { - // Check assignment operator - CaptureData assignableCaptureData; + // Check CaptureData setter and getter functions std::vector counterIds = {42, 29, 13}; - assignableCaptureData.SetCapturePeriod(3); - assignableCaptureData.SetCounterIds(counterIds); + CaptureData captureData; + BOOST_CHECK(captureData.GetCapturePeriod() == 0); + BOOST_CHECK((captureData.GetCounterIds()).empty()); + captureData.SetCapturePeriod(150); + captureData.SetCounterIds(counterIds); + BOOST_CHECK(captureData.GetCapturePeriod() == 150); + BOOST_CHECK(captureData.GetCounterIds() == counterIds); + // Check assignment operator CaptureData secondCaptureData; - BOOST_CHECK(assignableCaptureData.GetCapturePeriod() == 3); - BOOST_CHECK(assignableCaptureData.GetCounterIds() == counterIds); - - secondCaptureData = assignableCaptureData; - BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 3); + secondCaptureData = captureData; + BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 150); BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds); // Check copy constructor - CaptureData copyConstructedCaptureData(assignableCaptureData); + CaptureData copyConstructedCaptureData(captureData); - BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 3); + BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 150); BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds); } -- cgit v1.2.1