aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2019-09-17 12:29:50 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-09-18 15:13:43 +0000
commit8355ec982eb3ff51a6a8042fe760138638ca550b (patch)
tree23150be7ae10712db7af3d6e18f1a381d945344a
parent0bd586ceb2a1e3f8132d009cf48dc46c76ae09e4 (diff)
downloadarmnn-8355ec982eb3ff51a6a8042fe760138638ca550b.tar.gz
IVGCVSW-3432 Fix a multithread store conflict
* Unit test was using the same CaptureData object across 50 threads Change-Id: I0249b5a8e0bb05e3d3efdd855f5b34b1d5ef3dc9 Signed-off-by: Jim Flynn <jim.flynn@arm.com>
-rw-r--r--src/profiling/Holder.cpp6
-rw-r--r--src/profiling/Holder.hpp4
-rw-r--r--src/profiling/test/ProfilingTests.cpp97
3 files changed, 53 insertions, 54 deletions
diff --git a/src/profiling/Holder.cpp b/src/profiling/Holder.cpp
index 9def49d22e..5916017eb6 100644
--- a/src/profiling/Holder.cpp
+++ b/src/profiling/Holder.cpp
@@ -24,7 +24,7 @@ void CaptureData::SetCapturePeriod(uint32_t capturePeriod)
m_CapturePeriod = capturePeriod;
}
-void CaptureData::SetCounterIds(std::vector<uint16_t>& counterIds)
+void CaptureData::SetCounterIds(const std::vector<uint16_t>& counterIds)
{
m_CounterIds = counterIds;
}
@@ -45,7 +45,7 @@ CaptureData Holder::GetCaptureData() const
return m_CaptureData;
}
-void Holder::SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+void Holder::SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
{
std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
m_CaptureData.SetCapturePeriod(capturePeriod);
@@ -54,4 +54,4 @@ void Holder::SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& count
} // namespace profiling
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/Holder.hpp b/src/profiling/Holder.hpp
index c22c72a929..d8d1f5bfb4 100644
--- a/src/profiling/Holder.hpp
+++ b/src/profiling/Holder.hpp
@@ -26,7 +26,7 @@ public:
CaptureData& operator= (const CaptureData& captureData);
void SetCapturePeriod(uint32_t capturePeriod);
- void SetCounterIds(std::vector<uint16_t>& counterIds);
+ void SetCounterIds(const std::vector<uint16_t>& counterIds);
uint32_t GetCapturePeriod() const;
std::vector<uint16_t> GetCounterIds() const;
@@ -41,7 +41,7 @@ public:
Holder()
: m_CaptureData() {};
CaptureData GetCaptureData() const;
- void SetCaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds);
+ void SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds);
private:
mutable std::mutex m_CaptureThreadMutex;
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 3e59d09848..51dbb07a58 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -395,7 +395,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
BOOST_TEST((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
}
-void CaptureDataWriteThreadImpl(Holder &holder, uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+void CaptureDataWriteThreadImpl(Holder& holder, uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
{
holder.SetCaptureData(capturePeriod, counterIds);
}
@@ -409,22 +409,15 @@ BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder)
{
std::map<uint32_t, std::vector<uint16_t>> periodIdMap;
std::vector<uint16_t> counterIds;
- uint16_t numThreads = 50;
- for (uint16_t i = 0; i < numThreads; ++i)
+ uint32_t numThreads = 10;
+ for (uint32_t i = 0; i < numThreads; ++i)
{
counterIds.emplace_back(i);
periodIdMap.insert(std::make_pair(i, counterIds));
}
- // Check CaptureData functions
- CaptureData capture;
- BOOST_CHECK(capture.GetCapturePeriod() == 0);
- BOOST_CHECK((capture.GetCounterIds()).empty());
- capture.SetCapturePeriod(0);
- capture.SetCounterIds(periodIdMap[0]);
- BOOST_CHECK(capture.GetCapturePeriod() == 0);
- BOOST_CHECK(capture.GetCounterIds() == periodIdMap[0]);
-
+ // Verify the read and write threads set the holder correctly
+ // and retrieve the expected values
Holder holder;
BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0);
BOOST_CHECK(((holder.GetCaptureData()).GetCounterIds()).empty());
@@ -432,76 +425,82 @@ BOOST_AUTO_TEST_CASE(CheckCaptureDataHolder)
// Check Holder functions
std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), 2, std::ref(periodIdMap[2]));
thread1.join();
-
BOOST_CHECK((holder.GetCaptureData()).GetCapturePeriod() == 2);
BOOST_CHECK((holder.GetCaptureData()).GetCounterIds() == periodIdMap[2]);
+ // NOTE: now that we have some initial values in the holder we don't have to worry
+ // in the multi-threaded section below about a read thread accessing the holder
+ // before any write thread has gotten to it so we read period = 0, counterIds empty
+ // instead of period = 0, counterIds = {0} as will the case when write thread 0
+ // has executed.
CaptureData captureData;
std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
thread2.join();
+ BOOST_CHECK(captureData.GetCapturePeriod() == 2);
BOOST_CHECK(captureData.GetCounterIds() == periodIdMap[2]);
+ std::map<uint32_t, CaptureData> captureDataIdMap;
+ for (uint32_t i = 0; i < numThreads; ++i)
+ {
+ CaptureData perThreadCaptureData;
+ captureDataIdMap.insert(std::make_pair(i, perThreadCaptureData));
+ }
+
std::vector<std::thread> threadsVect;
- for (int i = 0; i < numThreads; i+=2)
+ std::vector<std::thread> readThreadsVect;
+ for (uint32_t i = 0; i < numThreads; ++i)
{
threadsVect.emplace_back(std::thread(CaptureDataWriteThreadImpl,
std::ref(holder),
i,
- std::ref(periodIdMap[static_cast<uint16_t >(i)])));
-
- threadsVect.emplace_back(std::thread(CaptureDataReadThreadImpl,
- std::ref(holder),
- std::ref(captureData)));
+ std::ref(periodIdMap[i])));
+
+ // Verify that the CaptureData goes into the thread in a virgin state
+ BOOST_CHECK(captureDataIdMap.at(i).GetCapturePeriod() == 0);
+ BOOST_CHECK(captureDataIdMap.at(i).GetCounterIds().empty());
+ readThreadsVect.emplace_back(std::thread(CaptureDataReadThreadImpl,
+ std::ref(holder),
+ std::ref(captureDataIdMap.at(i))));
}
- for (uint16_t i = 0; i < numThreads; ++i)
+ for (uint32_t i = 0; i < numThreads; ++i)
{
threadsVect[i].join();
+ readThreadsVect[i].join();
}
- std::vector<std::thread> readThreadsVect;
- for (uint16_t i = 0; i < numThreads; ++i)
- {
- readThreadsVect.emplace_back(
- std::thread(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData)));
- }
-
- for (uint16_t i = 0; i < numThreads; ++i)
+ // Look at the CaptureData that each read thread has filled
+ // the capture period it read should match the counter ids entry
+ for (uint32_t i = 0; i < numThreads; ++i)
{
- readThreadsVect[i].join();
+ CaptureData perThreadCaptureData = captureDataIdMap.at(i);
+ BOOST_CHECK(perThreadCaptureData.GetCounterIds() == periodIdMap.at(perThreadCaptureData.GetCapturePeriod()));
}
-
- // Check CaptureData was written/read correctly from multiple threads
- std::vector<uint16_t> captureIds = captureData.GetCounterIds();
- uint32_t capturePeriod = captureData.GetCapturePeriod();
-
- BOOST_CHECK(captureIds == periodIdMap[capturePeriod]);
-
- std::vector<uint16_t> readIds = holder.GetCaptureData().GetCounterIds();
- BOOST_CHECK(captureIds == readIds);
}
BOOST_AUTO_TEST_CASE(CaptureDataMethods)
{
- // Check assignment operator
- CaptureData assignableCaptureData;
+ // Check CaptureData setter and getter functions
std::vector<uint16_t> counterIds = {42, 29, 13};
- assignableCaptureData.SetCapturePeriod(3);
- assignableCaptureData.SetCounterIds(counterIds);
+ CaptureData captureData;
+ BOOST_CHECK(captureData.GetCapturePeriod() == 0);
+ BOOST_CHECK((captureData.GetCounterIds()).empty());
+ captureData.SetCapturePeriod(150);
+ captureData.SetCounterIds(counterIds);
+ BOOST_CHECK(captureData.GetCapturePeriod() == 150);
+ BOOST_CHECK(captureData.GetCounterIds() == counterIds);
+ // Check assignment operator
CaptureData secondCaptureData;
- BOOST_CHECK(assignableCaptureData.GetCapturePeriod() == 3);
- BOOST_CHECK(assignableCaptureData.GetCounterIds() == counterIds);
-
- secondCaptureData = assignableCaptureData;
- BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 3);
+ secondCaptureData = captureData;
+ BOOST_CHECK(secondCaptureData.GetCapturePeriod() == 150);
BOOST_CHECK(secondCaptureData.GetCounterIds() == counterIds);
// Check copy constructor
- CaptureData copyConstructedCaptureData(assignableCaptureData);
+ CaptureData copyConstructedCaptureData(captureData);
- BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 3);
+ BOOST_CHECK(copyConstructedCaptureData.GetCapturePeriod() == 150);
BOOST_CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds);
}