diff options
author | Jim Flynn <jim.flynn@arm.com> | 2019-09-17 12:29:50 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-09-18 15:13:43 +0000 |
commit | 8355ec982eb3ff51a6a8042fe760138638ca550b (patch) | |
tree | 23150be7ae10712db7af3d6e18f1a381d945344a | |
parent | 0bd586ceb2a1e3f8132d009cf48dc46c76ae09e4 (diff) | |
download | armnn-8355ec982eb3ff51a6a8042fe760138638ca550b.tar.gz |
IVGCVSW-3432 Fix a multithread store conflict
* Unit test was using the same CaptureData object across 50 threads
Change-Id: I0249b5a8e0bb05e3d3efdd855f5b34b1d5ef3dc9
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
-rw-r--r-- | src/profiling/Holder.cpp | 6 | ||||
-rw-r--r-- | src/profiling/Holder.hpp | 4 | ||||
-rw-r--r-- | src/profiling/test/ProfilingTests.cpp | 97 |
3 files changed, 53 insertions, 54 deletions
diff --git a/src/profiling/Holder.cpp b/src/profiling/Holder.cpp index 9def49d22e..5916017eb6 100644 --- a/src/profiling/Holder.cpp +++ b/src/profiling/Holder.cpp @@ -24,7 +24,7 @@ void CaptureData::SetCapturePeriod(uint32_t capturePeriod) m_CapturePeriod = capturePeriod; } -void CaptureData::SetCounterIds(std::vector<uint16_t>& counterIds) +void CaptureData::SetCounterIds(const std::vector<uint16_t>& counterIds) { m_CounterIds = counterIds; } @@ -45,7 +45,7 @@ CaptureData Holder::GetCaptureData() const return m_CaptureData; } -void Holder::SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds) +void Holder::SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds) { std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex); m_CaptureData.SetCapturePeriod(capturePeriod); @@ -54,4 +54,4 @@ void Holder::SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& count } // namespace profiling -} // namespace armnn
\ No newline at end of file +} // namespace armnn diff --git a/src/profiling/Holder.hpp b/src/profiling/Holder.hpp index c22c72a929..d8d1f5bfb4 100644 --- a/src/profiling/Holder.hpp +++ b/src/profiling/Holder.hpp @@ -26,7 +26,7 @@ public: CaptureData& operator= (const CaptureData& captureData); void SetCapturePeriod(uint32_t capturePeriod); - void SetCounterIds(std::vector<uint16_t>& counterIds); + void SetCounterIds(const std::vector<uint16_t>& counterIds); uint32_t GetCapturePeriod() const; std::vector<uint16_t> GetCounterIds() const; @@ -41,7 +41,7 @@ public: Holder() : m_CaptureData() {}; CaptureData GetCaptureData() const; - void SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds); + void SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds); private: mutable std::mutex m_CaptureThreadMutex; diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp index 3e59d09848..51dbb07a58 100644 --- a/src/profiling/test/ProfilingTests.cpp +++ b/src/profiling/test/ProfilingTests.cpp @@ -395,7 +395,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine) BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected)); } -void CaptureDataWriteThreadImpl(Holder &holder, uint32_t capturePeriod, std::vector<uint16_t>& counterIds) +void CaptureDataWriteThreadImpl(Holder& holder, uint32_t capturePeriod, const std::vector<uint16_t>& counterIds) { holder.SetCaptureData(capturePeriod, counterIds); } @@ -409,22 +409,15 @@ BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder) { std::map<uint32_t, std::vector<uint16_t>> periodIdMap; std::vector<uint16_t> counterIds; - uint16_t numThreads = 50; - for (uint16_t i = 0; i < numThreads; ++i) + uint32_t numThreads = 10; + for (uint32_t i = 0; i < numThreads; ++i) { counterIds.emplace_back(i); periodIdMap.insert(std::make_pair(i, counterIds)); } - // Check CaptureData functions - CaptureData capture; - BOOST_CHECK(capture.GetCapturePeriod() == 0); - BOOST_CHECK((capture.GetCounterIds()).empty()); - capture.SetCapturePeriod(0); - capture.SetCounterIds(periodIdMap[0]); - BOOST_CHECK(capture.GetCapturePeriod() == 0); - BOOST_CHECK(capture.GetCounterIds() == periodIdMap[0]); - + // Verify the read and write threads set the holder correctly + // and retrieve the expected values Holder holder; BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0); BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty()); @@ -432,76 +425,82 @@ BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder) // Check Holder functions std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), 2, std::ref(periodIdMap[2])); thread1.join(); - BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 2); BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == periodIdMap[2]); + // NOTE: now that we have some initial values in the holder we don't have to worry + // in the multi-threaded section below about a read thread accessing the holder + // before any write thread has gotten to it so we read period = 0, counterIds empty + // instead of period = 0, counterIds = {0} as will the case when write thread 0 + // has executed. CaptureData captureData; std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData)); thread2.join(); + BOOST_CHECK(captureData.GetCapturePeriod() == 2); BOOST_CHECK(captureData.GetCounterIds() == periodIdMap[2]); + std::map<uint32_t, CaptureData> captureDataIdMap; + for (uint32_t i = 0; i < numThreads; ++i) + { + CaptureData perThreadCaptureData; + captureDataIdMap.insert(std::make_pair(i, perThreadCaptureData)); + } + std::vector<std::thread> threadsVect; - for (int i = 0; i < numThreads; i+=2) + std::vector<std::thread> readThreadsVect; + for (uint32_t i = 0; i < numThreads; ++i) { threadsVect.emplace_back(std::thread(CaptureDataWriteThreadImpl, std::ref(holder), i, - std::ref(periodIdMap[static_cast<uint16_t >(i)]))); - - threadsVect.emplace_back(std::thread(CaptureDataReadThreadImpl, - std::ref(holder), - std::ref(captureData))); + std::ref(periodIdMap[i]))); + + // Verify that the CaptureData goes into the thread in a virgin state + BOOST_CHECK(captureDataIdMap.at(i).GetCapturePeriod() == 0); + BOOST_CHECK(captureDataIdMap.at(i).GetCounterIds().empty()); + readThreadsVect.emplace_back(std::thread(CaptureDataReadThreadImpl, + std::ref(holder), + std::ref(captureDataIdMap.at(i)))); } - for (uint16_t i = 0; i < numThreads; ++i) + for (uint32_t i = 0; i < numThreads; ++i) { threadsVect[i].join(); + readThreadsVect[i].join(); } - std::vector<std::thread> readThreadsVect; - for (uint16_t i = 0; i < numThreads; ++i) - { - readThreadsVect.emplace_back( - std::thread(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData))); - } - - for (uint16_t i = 0; i < numThreads; ++i) + // Look at the CaptureData that each read thread has filled + // the capture period it read should match the counter ids entry + for (uint32_t i = 0; i < numThreads; ++i) { - readThreadsVect[i].join(); + CaptureData perThreadCaptureData = captureDataIdMap.at(i); + BOOST_CHECK(perThreadCaptureData.GetCounterIds() == periodIdMap.at(perThreadCaptureData.GetCapturePeriod())); } - - // Check CaptureData was written/read correctly from multiple threads - std::vector<uint16_t> captureIds = captureData.GetCounterIds(); - uint32_t capturePeriod = captureData.GetCapturePeriod(); - - BOOST_CHECK(captureIds == periodIdMap[capturePeriod]); - - std::vector<uint16_t> readIds = holder.GetCaptureData().GetCounterIds(); - BOOST_CHECK(captureIds == readIds); } BOOST_AUTO_TEST_CASE(CaptureDataMethods) { - // Check assignment operator - CaptureData assignableCaptureData; + // Check CaptureData setter and getter functions std::vector<uint16_t> counterIds = {42, 29, 13}; - assignableCaptureData.SetCapturePeriod(3); - assignableCaptureData.SetCounterIds(counterIds); + CaptureData captureData; + BOOST_CHECK(captureData.GetCapturePeriod() == 0); + BOOST_CHECK((captureData.GetCounterIds()).empty()); + captureData.SetCapturePeriod(150); + captureData.SetCounterIds(counterIds); + BOOST_CHECK(captureData.GetCapturePeriod() == 150); + BOOST_CHECK(captureData.GetCounterIds() == counterIds); + // Check assignment operator CaptureData secondCaptureData; - BOOST_CHECK(assignableCaptureData.GetCapturePeriod() == 3); - BOOST_CHECK(assignableCaptureData.GetCounterIds() == counterIds); - - secondCaptureData = assignableCaptureData; - BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 3); + secondCaptureData = captureData; + BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 150); BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds); // Check copy constructor - CaptureData copyConstructedCaptureData(assignableCaptureData); + CaptureData copyConstructedCaptureData(captureData); - BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 3); + BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 150); BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds); } |