Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
 src/armnn/LoadedNetwork.cpp | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 3f4aa34a5b..3d84054b69 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -955,10 +955,10 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
syncDesc.m_Inputs.push_back(inputTensorHandle);
WorkloadInfo info;
info.m_InputTensorInfos.push_back(
- outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo());
+ outputLayer->GetInputSlot(0).GetTensorInfo());
auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
- m_OutputQueue.push_back(move(syncWorkload));
+ m_OutputQueue.push_back(std::move(syncWorkload));
importedOutputIdIndex++;
}
else
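Note: the change above swaps GetConnectedOutputSlot()->GetTensorInfo() for the input slot's own GetTensorInfo(), so a tensor info overridden on the slot takes precedence over the info published by the producing layer. A minimal standalone sketch of that fallback pattern, using toy stand-ins rather than the real armnn classes:

#include <iostream>
#include <optional>

// Toy stand-ins for armnn::TensorInfo / armnn::OutputSlot / armnn::InputSlot,
// only to show "use the override if present, otherwise the producer's info".
struct TensorInfo { unsigned int numElements; };

struct OutputSlot { TensorInfo info; };

struct InputSlot
{
    const OutputSlot* connectedOutput = nullptr;
    std::optional<TensorInfo> overriddenInfo;

    bool IsTensorInfoOverridden() const { return overriddenInfo.has_value(); }

    // Prefer the caller-supplied override; otherwise fall back to the connected OutputSlot.
    const TensorInfo& GetTensorInfo() const
    {
        return overriddenInfo ? *overriddenInfo : connectedOutput->info;
    }
};

int main()
{
    OutputSlot producer{ TensorInfo{ 16 } };
    InputSlot consumer{ &producer, TensorInfo{ 8 } };
    std::cout << consumer.GetTensorInfo().numElements << "\n"; // prints 8, the overridden info
}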
@@ -1089,7 +1089,7 @@ void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tens
timelineUtils->Commit();
}
- m_InputQueue.push_back(move(inputWorkload));
+ m_InputQueue.push_back(std::move(inputWorkload));
}
}
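Note: the move -> std::move changes in this and the surrounding hunks qualify calls that previously compiled only because argument-dependent lookup found std::move through the std::unique_ptr argument; the qualified form makes the ownership transfer explicit. A trivial illustration of the same pattern:

#include <memory>
#include <utility>
#include <vector>

int main()
{
    std::vector<std::unique_ptr<int>> queue;
    auto workload = std::make_unique<int>(42);

    // push_back(move(workload)) would also compile here, but only because ADL
    // finds std::move through std::unique_ptr; the qualified call is unambiguous.
    queue.push_back(std::move(workload));
}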
@@ -1149,7 +1149,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten
info.m_InputTensorInfos.push_back(inputTensorInfo);
auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
- m_OutputQueue.push_back(move(syncWorkload));
+ m_OutputQueue.push_back(std::move(syncWorkload));
}
else
{
@@ -1177,7 +1177,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten
timelineUtils->Commit();
}
- m_OutputQueue.push_back(move(outputWorkload));
+ m_OutputQueue.push_back(std::move(outputWorkload));
}
}
@@ -1650,7 +1650,7 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors&
const InputSlot& inputSlot = layer->GetInputSlots()[0];
ITensorHandleFactory::FactoryId factoryId = inputSlot.GetConnectedOutputSlot()->GetTensorHandleFactoryId();
- const TensorInfo& tensorInfo = inputSlot.GetConnectedOutputSlot()->GetTensorInfo();
+ const TensorInfo& tensorInfo = inputSlot.GetTensorInfo();
ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
ARMNN_ASSERT(handleFactory);
@@ -2093,6 +2093,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network
if (found != m_ConstantTensorHandles.end())
{
ITensorHandle* tensorHandle = found->second;
+ if (slot.IsTensorInfoOverridden())
+ {
+ ITensorHandle* decorated = tensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get();
+ if (decorated)
+ {
+ tensorHandle = decorated;
+ }
+ }
workingMemDescriptor.m_Inputs.push_back(tensorHandle);
// Odd case where a constant layer is connected to an output layer
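Note: the added block asks the existing handle to decorate itself with the slot's overridden TensorInfo, so the workload sees the overridden metadata while the underlying memory stays the same. A toy sketch of that decorator idea (the real ITensorHandle::DecorateTensorHandle is backend-specific and richer than this):

// Toy stand-ins for armnn::TensorInfo / armnn::ITensorHandle; not the real definitions.
struct TensorInfo { unsigned int numBytes; };

struct ITensorHandle
{
    virtual ~ITensorHandle() = default;
    virtual void* Map() = 0;                              // access to the backing memory
    virtual const TensorInfo& GetTensorInfo() const = 0;  // metadata the consumer sees
};

// Presents a different TensorInfo but forwards data access to the wrapped handle,
// so both views share the same underlying buffer.
class DecoratedTensorHandle : public ITensorHandle
{
public:
    DecoratedTensorHandle(ITensorHandle& inner, const TensorInfo& overridden)
        : m_Inner(inner), m_Info(overridden) {}

    void* Map() override { return m_Inner.Map(); }
    const TensorInfo& GetTensorInfo() const override { return m_Info; }

private:
    ITensorHandle& m_Inner;
    TensorInfo m_Info;
};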
@@ -2113,6 +2121,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network
HandleInfo& handleInfo = outputToHandleInfoMap.at(outputSlot);
ITensorHandle* inputTensorHandle = handleInfo.m_TensorHandle;
+ if (slot.IsTensorInfoOverridden())
+ {
+ ITensorHandle* decorated = inputTensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get();
+ if (decorated)
+ {
+ inputTensorHandle = decorated;
+ }
+ }
workingMemDescriptor.m_Inputs.push_back(inputTensorHandle);
// Store the LayerBindingId of the OutputLayer
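Note: this hunk and the previous one apply the same decorate-if-overridden step. Purely as an illustration (not part of the patch), the shared logic could be expressed once, assuming the includes and types already available in LoadedNetwork.cpp:

// Hypothetical helper, not in the patch: return the decorated handle when the
// slot's TensorInfo is overridden and decoration succeeds, otherwise the original.
auto DecorateIfOverridden = [](ITensorHandle* handle, const InputSlot& slot) -> ITensorHandle*
{
    if (slot.IsTensorInfoOverridden())
    {
        if (ITensorHandle* decorated = handle->DecorateTensorHandle(slot.GetTensorInfo()).get())
        {
            return decorated;
        }
    }
    return handle;
};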