author     Derek Lamberti <derek.lamberti@arm.com>    2018-10-02 15:52:46 +0100
committer  Matthew Bentham <matthew.bentham@arm.com>  2018-10-22 16:57:54 +0100
commit     03614f697396558a652f22c6efac2a3cd1f71460 (patch)
tree       07955162c6031fec3817de65949201b38360bd61
parent     c26ba759fe67bd14829a84b5abac80f51ca61946 (diff)
download   armnn-03614f697396558a652f22c6efac2a3cd1f71460.tar.gz
IVGCVSW-1823 *Free working mem only when network changes
Change-Id: I62b34713f8ebd96b9d4369f25cc8ba474aad8bb4
-rw-r--r--  src/armnn/LoadedNetwork.cpp | 69
-rw-r--r--  src/armnn/LoadedNetwork.hpp | 20
-rw-r--r--  src/armnn/Runtime.cpp       | 12
-rw-r--r--  src/armnn/Runtime.hpp       | 11
4 files changed, 88 insertions, 24 deletions
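
In outline, the patch stops acquiring and releasing backend working memory around every inference. Allocation becomes lazy and flag-guarded, and the memory is only given back when FreeWorkingMemory() is called, i.e. when the runtime switches to a different network or the LoadedNetwork is destroyed. A minimal standalone sketch of that allocate-lazily / free-on-demand pattern (the WorkingMemHolder class and its method names are illustrative, not part of the Arm NN API):

    #include <mutex>

    // Illustrative only: models the flag-guarded working-memory lifetime
    // introduced by this patch.
    class WorkingMemHolder
    {
    public:
        // Idempotent: acquires backend working memory the first time only.
        void Allocate()
        {
            std::lock_guard<std::mutex> guard(m_Mutex);
            if (m_Allocated)
            {
                return;
            }
            // ... acquire working memory from the backends here ...
            m_Allocated = true;
        }

        // Idempotent: releases the memory only if it is currently held.
        void Free()
        {
            std::lock_guard<std::mutex> guard(m_Mutex);
            if (!m_Allocated)
            {
                return;
            }
            // ... release working memory back to the backends here ...
            m_Allocated = false;
        }

        ~WorkingMemHolder() { Free(); }

    private:
        std::mutex m_Mutex;
        bool m_Allocated = false;
    };

With this shape, calling Allocate() before every inference is cheap once the memory is already held, and Free() can be deferred until the active network actually changes.
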
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index f49fa7b878..7aa66d9b09 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -72,6 +72,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
: m_CpuRef()
, m_OptimizedNetwork(std::move(net))
+ , m_WorkingMemLock(m_WorkingMemMutex, std::defer_lock)
{
// Create a profiler and register it for the current thread.
m_Profiler = std::make_shared<Profiler>();
@@ -303,6 +304,8 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
}
// For each input to the network, call EnqueueInput with the data passed by the user.
+ m_InputQueue.clear();
+ m_InputQueue.reserve(graph.GetNumInputs());
for (const BindableLayer* inputLayer : graph.GetInputLayers())
{
const TensorPin& pin = workloadData.GetInputTensorPin(inputLayer->GetBindingId());
@@ -310,6 +313,8 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
}
// For each output to the network, call EnqueueOutput with the data passed by the user.
+ m_OutputQueue.clear();
+ m_OutputQueue.reserve(graph.GetNumOutputs());
for (const BindableLayer* outputLayer : graph.GetOutputLayers())
{
const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId());
@@ -324,9 +329,6 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
executionSucceeded = Execute();
}
- // Hack: get rid of inputs and outputs we added.
- TidyWorkloadQueue(graph.GetNumInputs(), graph.GetNumOutputs());
-
return executionSucceeded ? Status::Success : Status::Failure;
}
@@ -360,7 +362,7 @@ void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tens
const IWorkloadFactory& workloadFactory = GetWorkloadFactory(layer);
auto inputWorkload = workloadFactory.CreateInput(inputQueueDescriptor, info);
BOOST_ASSERT_MSG(inputWorkload, "No input workload created");
- m_WorkloadQueue.insert(m_WorkloadQueue.begin(), move(inputWorkload));
+ m_InputQueue.push_back(move(inputWorkload));
}
void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo)
@@ -396,16 +398,39 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten
const IWorkloadFactory& workloadFactory = GetWorkloadFactory(layer);
auto outputWorkload = workloadFactory.CreateOutput(outputQueueDescriptor, info);
BOOST_ASSERT_MSG(outputWorkload, "No output workload created");
- m_WorkloadQueue.push_back(move(outputWorkload));
+ m_OutputQueue.push_back(move(outputWorkload));
}
-bool LoadedNetwork::Execute()
+void LoadedNetwork::AllocateWorkingMemory()
{
- bool success = true;
-
+ BOOST_ASSERT_MSG(m_WorkingMemLock.owns_lock(), "Cannot allocate working memory if mutex is not already locked.");
+ if (m_IsWorkingMemAllocated)
+ {
+ return;
+ }
m_CpuRef.Acquire();
m_CpuAcc.Acquire();
m_GpuAcc.Acquire();
+ m_IsWorkingMemAllocated = true;
+}
+
+void LoadedNetwork::FreeWorkingMemory()
+{
+ std::lock_guard<UniqueMutexLock> lockGuard(m_WorkingMemLock);
+ if (!m_IsWorkingMemAllocated)
+ {
+ return;
+ }
+ // Informs the memory managers to release memory in their respective memory groups
+ m_CpuRef.Release();
+ m_CpuAcc.Release();
+ m_GpuAcc.Release();
+ m_IsWorkingMemAllocated = false;
+}
+
+bool LoadedNetwork::Execute()
+{
+ bool success = true;
auto Fail = [&](const std::exception& error)
{
@@ -415,9 +440,22 @@ bool LoadedNetwork::Execute()
try
{
- for (size_t i = 0; i < m_WorkloadQueue.size(); ++i)
+ std::lock_guard<UniqueMutexLock> lockGuard(m_WorkingMemLock);
+ AllocateWorkingMemory();
+
+ for (auto& input : m_InputQueue)
+ {
+ input->Execute();
+ }
+
+ for (auto& workload : m_WorkloadQueue)
+ {
+ workload->Execute();
+ }
+
+ for (auto& output: m_OutputQueue)
{
- m_WorkloadQueue[i]->Execute();
+ output->Execute();
}
}
catch (const RuntimeException& error)
@@ -429,18 +467,7 @@ bool LoadedNetwork::Execute()
Fail(error);
}
- // Informs the memory managers to release memory in it's respective memory group
- m_CpuRef.Release();
- m_CpuAcc.Release();
- m_GpuAcc.Release();
-
return success;
}
-void LoadedNetwork::TidyWorkloadQueue(size_t numInputs, size_t numOutputs)
-{
- m_WorkloadQueue.erase(m_WorkloadQueue.begin(), m_WorkloadQueue.begin() + boost::numeric_cast<long>(numInputs));
- m_WorkloadQueue.erase(m_WorkloadQueue.end() - boost::numeric_cast<long>(numOutputs), m_WorkloadQueue.end());
-}
-
}
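
One detail in the new Execute() and FreeWorkingMemory() is that both wrap m_WorkingMemLock, a std::unique_lock constructed with std::defer_lock, in a std::lock_guard. That works because std::unique_lock is itself BasicLockable: its lock()/unlock() members drive the associated mutex. A small self-contained illustration of that behaviour:

    #include <cassert>
    #include <mutex>

    int main()
    {
        std::mutex mtx;
        // Associated with mtx but not yet locked, like m_WorkingMemLock.
        std::unique_lock<std::mutex> deferred(mtx, std::defer_lock);

        {
            // lock_guard treats the unique_lock as the lockable object:
            // entering the scope locks mtx, leaving it unlocks mtx again.
            std::lock_guard<std::unique_lock<std::mutex>> guard(deferred);
            assert(deferred.owns_lock());
        }
        assert(!deferred.owns_lock());
        return 0;
    }
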
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 7808cc19e3..3deb8bc2e2 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -17,6 +17,8 @@
#include <backends/Workload.hpp>
#include <backends/WorkloadFactory.hpp>
+#include <mutex>
+
namespace cl
{
class Context;
@@ -30,6 +32,9 @@ namespace armnn
class LoadedNetwork
{
public:
+ using WorkloadQueue = std::vector< std::unique_ptr<IWorkload> >;
+ ~LoadedNetwork(){ FreeWorkingMemory(); }
+
TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
@@ -43,6 +48,9 @@ public:
// the shared_ptr's reference counter
const std::shared_ptr<Profiler>& GetProfiler() const { return m_Profiler; }
+ void AllocateWorkingMemory();
+ void FreeWorkingMemory();
+
private:
LoadedNetwork(std::unique_ptr<OptimizedNetwork> net);
@@ -52,8 +60,6 @@ private:
bool Execute();
- void TidyWorkloadQueue(size_t numInputs, size_t numOutputs);
-
const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
RefWorkloadFactory m_CpuRef;
@@ -61,8 +67,16 @@ private:
ClWorkloadFactory m_GpuAcc;
std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork;
- std::vector< std::unique_ptr<IWorkload> > m_WorkloadQueue;
+ WorkloadQueue m_InputQueue;
+ WorkloadQueue m_WorkloadQueue;
+ WorkloadQueue m_OutputQueue;
std::shared_ptr<Profiler> m_Profiler;
+
+ using UniqueMutexLock = std::unique_lock<std::mutex>;
+ mutable std::mutex m_WorkingMemMutex;
+ UniqueMutexLock m_WorkingMemLock;
+
+ bool m_IsWorkingMemAllocated=false;
};
}
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 8a7023ed76..e84cbe0a60 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -195,11 +195,23 @@ TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId laye
return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
}
+
Status Runtime::EnqueueWorkload(NetworkId networkId,
const InputTensors& inputTensors,
const OutputTensors& outputTensors)
{
LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+
+ static thread_local NetworkId lastId = networkId;
+ if (lastId != networkId)
+ {
+ LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
+ {
+ network->FreeWorkingMemory();
+ });
+ }
+ lastId=networkId;
+
return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
}
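
The Runtime side keeps a thread_local record of the last network id enqueued on the calling thread and frees that network's working memory only when a different id arrives. A simplified, self-contained sketch of the same dispatch logic (locking is omitted for brevity; Net, g_Networks and RunInference are illustrative stand-ins, not armnn names):

    #include <unordered_map>

    // Hypothetical stand-ins for LoadedNetwork and Runtime::m_LoadedNetworks.
    struct Net
    {
        bool allocated = false;
        void Allocate() { if (!allocated) { /* acquire working memory */ allocated = true; } }
        void Free()     { if (allocated)  { /* release working memory */ allocated = false; } }
    };

    std::unordered_map<int, Net> g_Networks;

    void RunInference(int networkId)
    {
        // First call on this thread initialises lastId to the current id,
        // so nothing is freed on the very first inference.
        static thread_local int lastId = networkId;
        if (lastId != networkId)
        {
            // The thread is switching networks: only now hand back the
            // previous network's working memory.
            g_Networks[lastId].Free();
        }
        lastId = networkId;

        g_Networks[networkId].Allocate();   // lazy, no-op when already held
        // ... execute the input, workload and output queues ...
    }

    int main()
    {
        RunInference(1);   // allocates network 1's working memory
        RunInference(1);   // reuses it; nothing freed or re-acquired
        RunInference(2);   // frees network 1's memory, allocates network 2's
        return 0;
    }
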
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index 12211f4e04..29bb6808d8 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -73,6 +73,17 @@ private:
LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
+ template<typename Func>
+ void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
+ {
+ std::lock_guard<std::mutex> lockGuard(m_Mutex);
+ auto iter = m_LoadedNetworks.find(networkId);
+ if (iter != m_LoadedNetworks.end())
+ {
+ f(iter->second.get());
+ }
+ }
+
mutable std::mutex m_Mutex;
std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
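
LoadedNetworkFuncSafe exists because the previously used network may have been unloaded by the time the next EnqueueWorkload runs: the helper takes the runtime mutex, re-checks that the id is still present in m_LoadedNetworks, and only then applies the functor. A stripped-down sketch of the same lookup-and-apply shape (Network, g_LoadedNetworks and WithLoadedNetwork are illustrative names, not the armnn API):

    #include <memory>
    #include <mutex>
    #include <unordered_map>

    struct Network { void FreeWorkingMemory() { /* release working memory */ } };

    std::mutex g_Mutex;
    std::unordered_map<int, std::unique_ptr<Network>> g_LoadedNetworks;

    // Apply f to the network only if it is still loaded, under the map lock.
    template <typename Func>
    void WithLoadedNetwork(int id, Func f)
    {
        std::lock_guard<std::mutex> guard(g_Mutex);
        auto iter = g_LoadedNetworks.find(id);
        if (iter != g_LoadedNetworks.end())
        {
            f(iter->second.get());
        }
    }

    int main()
    {
        g_LoadedNetworks[7] = std::make_unique<Network>();
        WithLoadedNetwork(7, [](Network* n) { n->FreeWorkingMemory(); });  // runs
        WithLoadedNetwork(8, [](Network* n) { n->FreeWorkingMemory(); });  // no-op: not loaded
        return 0;
    }
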