aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Narumol Prangnawarat <narumol.prangnawarat@arm.com> 2022-02-04 17:50:20 +0000
committer TeresaARM <teresa.charlinreyes@arm.com> 2022-02-06 19:20:11 +0000
commit ec5463d394453d268706ee6213b1c6a7619d4a5d (patch)
tree 98d17988da24affbe57a324241d9407ae3345f83
parent 23c26277086c78704a17f0dae86da947816320c0 (diff)
download armnn-ec5463d394453d268706ee6213b1c6a7619d4a5d.tar.gz
IVGCVSW-6747 Call Cl sync after EnqueueWorkload
* Add AfterEnqueueWorkload to IBackendContext
* Implement AfterEnqueueWorkload in ClBackendContext to call Cl sync
* Set allocated data on OutputHandler only once
* Handle PreImportedHandles and CurImportedId the same way as Async

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I9f59d57e298d4a494569faec3078d66af799f77b
-rw-r--r-- include/armnn/backends/IBackendContext.hpp 3
-rw-r--r-- src/armnn/LoadedNetwork.cpp 30
-rw-r--r-- src/armnn/OutputHandler.cpp 9
-rw-r--r-- src/armnn/OutputHandler.hpp 8
-rw-r--r-- src/armnn/Runtime.cpp 7
-rw-r--r-- src/backends/cl/ClBackendContext.cpp 5
-rw-r--r-- src/backends/cl/ClBackendContext.hpp 2
7 files changed, 53 insertions, 11 deletions
diff --git a/include/armnn/backends/IBackendContext.hpp b/include/armnn/backends/IBackendContext.hpp
index ae85b6354b..6fca42d2ed 100644
--- a/include/armnn/backends/IBackendContext.hpp
+++ b/include/armnn/backends/IBackendContext.hpp
@@ -25,6 +25,9 @@ public:
virtual bool BeforeUnloadNetwork(NetworkId networkId) = 0;
virtual bool AfterUnloadNetwork(NetworkId networkId) = 0;
+ // After Enqueue workload events
+ virtual bool AfterEnqueueWorkload(NetworkId networkId) = 0;
+
virtual ~IBackendContext() {}
};
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 45891f7dc3..48a3040b23 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -699,7 +699,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
for (const BindableLayer* inputLayer : graph.GetInputLayers())
{
- if (preImportedInputIds.size() != m_PreImportedInputHandles.size())
+ if (preImportedInputIds.size() > graph.GetNumInputs())
{
throw InvalidArgumentException("Invalid number of preImportedInputIds");
}
@@ -727,7 +727,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
for (const BindableLayer* outputLayer : graph.GetOutputLayers())
{
- if (preImportedOutputIds.size() != m_PreImportedOutputHandles.size())
+ if (preImportedOutputIds.size() > graph.GetNumOutputs())
{
throw InvalidArgumentException("Invalid number of preImportedOutputIds");
}
@@ -770,11 +770,6 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
}
}
}
- // Clear m_PreImportedInputHandles and m_PreImportedOutputHandles
- m_PreImportedInputHandles.clear();
- m_PreImportedOutputHandles.clear();
- m_CurImportedInputId = 0;
- m_CurImportedOutputId = 0;
std::unique_ptr<TimelineUtilityMethods> timelineUtils =
TimelineUtilityMethods::GetTimelineUtils(m_ProfilingService);
@@ -1271,6 +1266,16 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
{
// Cannot import, use allocated data
handler.UseAllocatedData();
+ // Ensure that the workload gets the correct tensor
+ try
+ {
+ m_WorkloadQueue[m_InputWorkloadSlotPairs[layerBindingId].first].get()->ReplaceInputTensorHandle(
+ handler.GetData(), m_InputWorkloadSlotPairs[layerBindingId].second);
+ }
+ catch(armnn::UnimplementedException& e)
+ {
+ IgnoreUnused(e);
+ }
}
}
@@ -1437,6 +1442,17 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors&
{
// Cannot import, use allocated memory
outputHandler.UseAllocatedData();
+ // Ensure that the workload gets the correct tensor
+ try
+ {
+ m_WorkloadQueue[m_OutputWorkloadSlotPairs[layerBindingId].first].get()->
+ ReplaceOutputTensorHandle(outputHandler.GetData(),
+ m_OutputWorkloadSlotPairs[layerBindingId].second);
+ }
+ catch(armnn::UnimplementedException& e)
+ {
+ IgnoreUnused(e);
+ }
}
}
return importedOutputs;
diff --git a/src/armnn/OutputHandler.cpp b/src/armnn/OutputHandler.cpp
index 807262e482..8827d8ac58 100644
--- a/src/armnn/OutputHandler.cpp
+++ b/src/armnn/OutputHandler.cpp
@@ -35,4 +35,13 @@ void OutputHandler::CollectWorkloadOutputs(WorkloadDataCollector& dataCollector)
dataCollector.Push(m_TensorHandle.get(), m_TensorInfo);
}
+void OutputHandler::SetAllocatedData()
+{
+ // Set allocated data only once: if m_AllocatedTensorHandle is already set, keep it and leave m_TensorHandle untouched
+ if (!m_AllocatedTensorHandle)
+ {
+ m_AllocatedTensorHandle = std::move(m_TensorHandle);
+ }
+}
+
} // namespace armnn
diff --git a/src/armnn/OutputHandler.hpp b/src/armnn/OutputHandler.hpp
index 3fd2519ed5..d1cb2deea3 100644
--- a/src/armnn/OutputHandler.hpp
+++ b/src/armnn/OutputHandler.hpp
@@ -50,15 +50,15 @@ public:
void SetData(std::unique_ptr<ITensorHandle> data) { m_TensorHandle = std::move(data); }
- void SetAllocatedData() { m_AllocatedTensorHandle = std::move(m_TensorHandle); }
+ void SetAllocatedData();
- void UseAllocatedData() { m_TensorHandle = std::move(m_AllocatedTensorHandle); }
+ void UseAllocatedData() { m_TensorHandle = m_AllocatedTensorHandle; }
/// @brief Returns true if SetTensorInfo() has been called at least once on this.
bool IsTensorInfoSet() const { return m_bTensorInfoSet; }
private:
- std::unique_ptr<ITensorHandle> m_TensorHandle;
- std::unique_ptr<ITensorHandle> m_AllocatedTensorHandle;
+ std::shared_ptr<ITensorHandle> m_TensorHandle;
+ std::shared_ptr<ITensorHandle> m_AllocatedTensorHandle;
TensorInfo m_TensorInfo;
bool m_bTensorInfoSet = false;
};
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 95fb8a3abb..1abe0f394b 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -242,6 +242,7 @@ Status RuntimeImpl::UnloadNetwork(NetworkId networkId)
profiling::LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS);
}
}
+
if (m_LoadedNetworks.erase(networkId) == 0)
{
ARMNN_LOG(warning) << "WARNING: RuntimeImpl::UnloadNetwork(): " << networkId << " not found!";
@@ -632,6 +633,12 @@ Status RuntimeImpl::EnqueueWorkload(NetworkId networkId,
ARMNN_LOG(info) << "Execution time: " << std::setprecision(2)
<< std::fixed << armnn::GetTimeDuration(startTime).count() << " ms.";
+ // Call After EnqueueWorkload events
+ for (auto&& context : m_BackendContexts)
+ {
+ context.second->AfterEnqueueWorkload(networkId);
+ }
+
return status;
}
diff --git a/src/backends/cl/ClBackendContext.cpp b/src/backends/cl/ClBackendContext.cpp
index 9c5cca9d3a..5358fe9c79 100644
--- a/src/backends/cl/ClBackendContext.cpp
+++ b/src/backends/cl/ClBackendContext.cpp
@@ -285,6 +285,11 @@ bool ClBackendContext::AfterUnloadNetwork(NetworkId networkId)
return true;
}
+bool ClBackendContext::AfterEnqueueWorkload(NetworkId)
+{
+ return m_ClContextControlWrapper->Sync(); // NetworkId is unused; the result of Sync() is returned as-is
+}
+
ClBackendContext::~ClBackendContext()
{
if (m_Tuner && !m_TuningFile.empty())
diff --git a/src/backends/cl/ClBackendContext.hpp b/src/backends/cl/ClBackendContext.hpp
index af988a96dc..659d47b7c2 100644
--- a/src/backends/cl/ClBackendContext.hpp
+++ b/src/backends/cl/ClBackendContext.hpp
@@ -25,6 +25,8 @@ public:
bool BeforeUnloadNetwork(NetworkId networkId) override;
bool AfterUnloadNetwork(NetworkId networkId) override;
+ bool AfterEnqueueWorkload(NetworkId networkId) override;
+
~ClBackendContext() override;
private: