aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r--src/armnn/LoadedNetwork.cpp30
1 files changed, 23 insertions, 7 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 45891f7dc3..48a3040b23 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -699,7 +699,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
for (const BindableLayer* inputLayer : graph.GetInputLayers())
{
- if (preImportedInputIds.size() != m_PreImportedInputHandles.size())
+ if (preImportedInputIds.size() > graph.GetNumInputs())
{
throw InvalidArgumentException("Invalid number of preImportedInputIds");
}
@@ -727,7 +727,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
for (const BindableLayer* outputLayer : graph.GetOutputLayers())
{
- if (preImportedOutputIds.size() != m_PreImportedOutputHandles.size())
+ if (preImportedOutputIds.size() > graph.GetNumOutputs())
{
throw InvalidArgumentException("Invalid number of preImportedOutputIds");
}
@@ -770,11 +770,6 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
}
}
}
- // Clear m_PreImportedInputHandles and m_PreImportedOutputHandles
- m_PreImportedInputHandles.clear();
- m_PreImportedOutputHandles.clear();
- m_CurImportedInputId = 0;
- m_CurImportedOutputId = 0;
std::unique_ptr<TimelineUtilityMethods> timelineUtils =
TimelineUtilityMethods::GetTimelineUtils(m_ProfilingService);
@@ -1271,6 +1266,16 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
{
// Cannot import, use allocated data
handler.UseAllocatedData();
+ // Ensure that the workload get correct tensor
+ try
+ {
+ m_WorkloadQueue[m_InputWorkloadSlotPairs[layerBindingId].first].get()->ReplaceInputTensorHandle(
+ handler.GetData(), m_InputWorkloadSlotPairs[layerBindingId].second);
+ }
+ catch(armnn::UnimplementedException& e)
+ {
+ IgnoreUnused(e);
+ }
}
}
@@ -1437,6 +1442,17 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors&
{
// Cannot import, use allocated memory
outputHandler.UseAllocatedData();
+ // Ensure that the workload get correct tensor
+ try
+ {
+ m_WorkloadQueue[m_OutputWorkloadSlotPairs[layerBindingId].first].get()->
+ ReplaceOutputTensorHandle(outputHandler.GetData(),
+ m_OutputWorkloadSlotPairs[layerBindingId].second);
+ }
+ catch(armnn::UnimplementedException& e)
+ {
+ IgnoreUnused(e);
+ }
}
}
return importedOutputs;