diff options
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 3f4aa34a5b..3d84054b69 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -955,10 +955,10 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, syncDesc.m_Inputs.push_back(inputTensorHandle); WorkloadInfo info; info.m_InputTensorInfos.push_back( - outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo()); + outputLayer->GetInputSlot(0).GetTensorInfo()); auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info); ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created"); - m_OutputQueue.push_back(move(syncWorkload)); + m_OutputQueue.push_back(std::move(syncWorkload)); importedOutputIdIndex++; } else @@ -1089,7 +1089,7 @@ void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tens timelineUtils->Commit(); } - m_InputQueue.push_back(move(inputWorkload)); + m_InputQueue.push_back(std::move(inputWorkload)); } } @@ -1149,7 +1149,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten info.m_InputTensorInfos.push_back(inputTensorInfo); auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info); ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created"); - m_OutputQueue.push_back(move(syncWorkload)); + m_OutputQueue.push_back(std::move(syncWorkload)); } else { @@ -1177,7 +1177,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten timelineUtils->Commit(); } - m_OutputQueue.push_back(move(outputWorkload)); + m_OutputQueue.push_back(std::move(outputWorkload)); } } @@ -1650,7 +1650,7 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors& const InputSlot& inputSlot = layer->GetInputSlots()[0]; ITensorHandleFactory::FactoryId factoryId = inputSlot.GetConnectedOutputSlot()->GetTensorHandleFactoryId(); - const TensorInfo& tensorInfo = inputSlot.GetConnectedOutputSlot()->GetTensorInfo(); + const TensorInfo& tensorInfo = inputSlot.GetTensorInfo(); ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId); ARMNN_ASSERT(handleFactory); @@ -2093,6 +2093,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network if (found != m_ConstantTensorHandles.end()) { ITensorHandle* tensorHandle = found->second; + if (slot.IsTensorInfoOverridden()) + { + ITensorHandle* decorated = tensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get(); + if (decorated) + { + tensorHandle = decorated; + } + } workingMemDescriptor.m_Inputs.push_back(tensorHandle); // Odd case where a constant layer is connected to an output layer @@ -2113,6 +2121,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network HandleInfo& handleInfo = outputToHandleInfoMap.at(outputSlot); ITensorHandle* inputTensorHandle = handleInfo.m_TensorHandle; + if (slot.IsTensorInfoOverridden()) + { + ITensorHandle* decorated = inputTensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get(); + if (decorated) + { + inputTensorHandle = decorated; + } + } workingMemDescriptor.m_Inputs.push_back(inputTensorHandle); // Store the LayerBindingId of the OutputLayer |