From 626bd90378670eb5fd76f94526395430b752ad9e Mon Sep 17 00:00:00 2001 From: Francis Murtagh Date: Tue, 21 Jun 2022 13:16:23 +0000 Subject: Revert "Revert "IVGCVSW-6873 Import inputs but don't export outputs fails."" This reverts commit a0f8b15d4ddb5075f380003ff31b271d389d3b66. Reason for revert: Change-Id: Ibc4a77fa008643849da7330391942e4c87b941e2 --- src/armnn/LoadedNetwork.cpp | 96 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 91 insertions(+), 5 deletions(-) (limited to 'src/armnn/LoadedNetwork.cpp') diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index ec79d5da3e..a27add921e 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -84,6 +84,87 @@ void AddWorkloadStructure(std::unique_ptr& timelineUtils } // anonymous +/** + * This function performs a sanity check to ensure that the combination of input and output memory source matches the + * values for importEnabled and exportEnabled that were specified during optimization. During optimization the tensor + * handle factories are chosen based on whether import and export are enabled. If the user then specifies something + * incompatible here it can lead to problems. + * + * @param optimizedOptions + * @param networkProperties + */ +void ValidateSourcesMatchOptimizedNetwork(std::vector optimizedOptions, + const INetworkProperties& networkProperties) +{ + // Find the "Global" backend options. During the optimize phase the values of importEnabled and exportEnabled are + // added as backend options. + const vector::iterator& backendItr = + find_if(optimizedOptions.begin(), optimizedOptions.end(), [](const BackendOptions& backend) { + if (backend.GetBackendId().Get() == "Global") + { + return true; + } + else + { + return false; + } + }); + bool importEnabled = false; + bool exportEnabled = false; + if (backendItr != optimizedOptions.end()) + { + // Find the importEnabled and exportEnabled values. + for (size_t i = 0; i < backendItr->GetOptionCount(); i++) + { + const BackendOptions::BackendOption& option = backendItr->GetOption(i); + if (option.GetName() == "ImportEnabled") + { + importEnabled = option.GetValue().AsBool(); + } + if (option.GetName() == "ExportEnabled") + { + exportEnabled = option.GetValue().AsBool(); + } + } + } + + // Now that we have values for import and export compare them to the MemorySource variables. + // Any value of MemorySource that's not "Undefined" implies that we need to do an import of some kind. + if ((networkProperties.m_InputSource == MemorySource::Undefined && importEnabled) || + (networkProperties.m_InputSource != MemorySource::Undefined && !importEnabled)) + { + auto message = fmt::format("The input memory source specified, '{0}',", networkProperties.m_InputSource); + if (!importEnabled) + { + message.append(" requires that memory import be enabled. However, " + "it was disabled when this network was optimized."); + } + else + { + message.append(" requires that memory import be disabled. However, " + "it was enabled when this network was optimized."); + } + throw InvalidArgumentException(message); + } + + if ((networkProperties.m_OutputSource == MemorySource::Undefined && exportEnabled) || + (networkProperties.m_OutputSource != MemorySource::Undefined && !exportEnabled)) + { + auto message = fmt::format("The output memory source specified, '{0}',", networkProperties.m_OutputSource); + if (!exportEnabled) + { + message.append(" requires that memory export be enabled. However, " + "it was disabled when this network was optimized."); + } + else + { + message.append(" requires that memory export be disabled. However, " + "it was enabled when this network was optimized."); + } + throw InvalidArgumentException(message); + } +} // anonymous + std::unique_ptr LoadedNetwork::MakeLoadedNetwork(std::unique_ptr net, std::string& errorMessage, const INetworkProperties& networkProperties, @@ -136,6 +217,11 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr net, profiler->EnableNetworkDetailsToStdOut(networkProperties.m_OutputNetworkDetailsMethod); + // We need to check that the memory sources match up with the values of import and export specified during the + // optimize phase. If they don't this will throw an exception. + ValidateSourcesMatchOptimizedNetwork(m_OptimizedNetwork.get()->pOptimizedNetworkImpl->GetModelOptions(), + m_NetworkProperties); + //First create tensor handlers, backends and workload factories. //Handlers are created before workloads are. //Because workload creation can modify some of the handlers, @@ -1439,7 +1525,7 @@ std::vector LoadedNetwork::ImportInputs(const InputTensors& inp ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get(); - if (!CheckFlag(tensorHandle->GetImportFlags(), m_NetworkProperties.m_InputSource)) + if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource)) { throw MemoryImportException( fmt::format("ImportInputs: Memory Import failed, backend: " @@ -1451,7 +1537,7 @@ std::vector LoadedNetwork::ImportInputs(const InputTensors& inp std::make_unique(inputTensor.second.GetInfo(), inputTensor.second.GetMemoryArea()); - if (tensorHandle->Import(passThroughTensorHandle->Map(), m_NetworkProperties.m_InputSource)) + if (tensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource)) { importedInputs.push_back(m_CurImportedInputId++); passThroughTensorHandle->Unmap(); @@ -1564,14 +1650,14 @@ std::vector LoadedNetwork::ImportOutputs(const OutputTensors& ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get(); - if (!CheckFlag(tensorHandle->GetImportFlags(), m_NetworkProperties.m_OutputSource)) + if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource)) { throw MemoryImportException(fmt::format("ImportInputs: Memory Import failed, backend: " "{} does not support importing from source {}" - , factoryId, m_NetworkProperties.m_OutputSource)); + , factoryId, forceImportMemorySource)); } - if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), m_NetworkProperties.m_OutputSource)) + if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource)) { importedOutputs.push_back(m_CurImportedOutputId++); } -- cgit v1.2.1