path: root/src/armnn/LoadedNetwork.cpp
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r--  src/armnn/LoadedNetwork.cpp  96
1 file changed, 5 insertions(+), 91 deletions(-)
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index a27add921e..ec79d5da3e 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -84,87 +84,6 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils
} // anonymous
-/**
- * This function performs a sanity check to ensure that the combination of input and output memory source matches the
- * values for importEnabled and exportEnabled that were specified during optimization. During optimization the tensor
- * handle factories are chosen based on whether import and export are enabled. If the user then specifies something
- * incompatible here it can lead to problems.
- *
- * @param optimizedOptions
- * @param networkProperties
- */
-void ValidateSourcesMatchOptimizedNetwork(std::vector<BackendOptions> optimizedOptions,
- const INetworkProperties& networkProperties)
-{
- // Find the "Global" backend options. During the optimize phase the values of importEnabled and exportEnabled are
- // added as backend options.
- const vector<BackendOptions>::iterator& backendItr =
- find_if(optimizedOptions.begin(), optimizedOptions.end(), [](const BackendOptions& backend) {
- if (backend.GetBackendId().Get() == "Global")
- {
- return true;
- }
- else
- {
- return false;
- }
- });
- bool importEnabled = false;
- bool exportEnabled = false;
- if (backendItr != optimizedOptions.end())
- {
- // Find the importEnabled and exportEnabled values.
- for (size_t i = 0; i < backendItr->GetOptionCount(); i++)
- {
- const BackendOptions::BackendOption& option = backendItr->GetOption(i);
- if (option.GetName() == "ImportEnabled")
- {
- importEnabled = option.GetValue().AsBool();
- }
- if (option.GetName() == "ExportEnabled")
- {
- exportEnabled = option.GetValue().AsBool();
- }
- }
- }
-
- // Now that we have values for import and export compare them to the MemorySource variables.
- // Any value of MemorySource that's not "Undefined" implies that we need to do an import of some kind.
- if ((networkProperties.m_InputSource == MemorySource::Undefined && importEnabled) ||
- (networkProperties.m_InputSource != MemorySource::Undefined && !importEnabled))
- {
- auto message = fmt::format("The input memory source specified, '{0}',", networkProperties.m_InputSource);
- if (!importEnabled)
- {
- message.append(" requires that memory import be enabled. However, "
- "it was disabled when this network was optimized.");
- }
- else
- {
- message.append(" requires that memory import be disabled. However, "
- "it was enabled when this network was optimized.");
- }
- throw InvalidArgumentException(message);
- }
-
- if ((networkProperties.m_OutputSource == MemorySource::Undefined && exportEnabled) ||
- (networkProperties.m_OutputSource != MemorySource::Undefined && !exportEnabled))
- {
- auto message = fmt::format("The output memory source specified, '{0}',", networkProperties.m_OutputSource);
- if (!exportEnabled)
- {
- message.append(" requires that memory export be enabled. However, "
- "it was disabled when this network was optimized.");
- }
- else
- {
- message.append(" requires that memory export be disabled. However, "
- "it was enabled when this network was optimized.");
- }
- throw InvalidArgumentException(message);
- }
-} // anonymous
-
std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
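
For reference, the deleted ValidateSourcesMatchOptimizedNetwork() boiled down to one consistency rule: a load-time memory source other than MemorySource::Undefined is only acceptable when the matching ImportEnabled/ExportEnabled option was set under the "Global" backend options at optimize time, and vice versa. A minimal standalone sketch of that rule, using simplified stand-in types rather than the real armnn BackendOptions/INetworkProperties API:

#include <stdexcept>

enum class MemorySource { Undefined, Malloc, DmaBuf };

// Stand-in for the ImportEnabled/ExportEnabled values recorded under the
// "Global" backend options during optimization.
struct OptimizeTimeFlags
{
    bool importEnabled = false;
    bool exportEnabled = false;
};

// Stand-in for the relevant parts of INetworkProperties passed at load time.
struct LoadTimeSources
{
    MemorySource inputSource  = MemorySource::Undefined;
    MemorySource outputSource = MemorySource::Undefined;
};

// Throws when the load-time sources contradict the optimize-time flags,
// mirroring the checks in the removed function.
void ValidateSources(const OptimizeTimeFlags& flags, const LoadTimeSources& sources)
{
    const bool wantsImport = sources.inputSource  != MemorySource::Undefined;
    const bool wantsExport = sources.outputSource != MemorySource::Undefined;

    if (wantsImport != flags.importEnabled)
    {
        throw std::invalid_argument("Input memory source does not match the "
                                    "ImportEnabled value used during optimization");
    }
    if (wantsExport != flags.exportEnabled)
    {
        throw std::invalid_argument("Output memory source does not match the "
                                    "ExportEnabled value used during optimization");
    }
}
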
@@ -217,11 +136,6 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
profiler->EnableNetworkDetailsToStdOut(networkProperties.m_OutputNetworkDetailsMethod);
- // We need to check that the memory sources match up with the values of import and export specified during the
- // optimize phase. If they don't this will throw an exception.
- ValidateSourcesMatchOptimizedNetwork(m_OptimizedNetwork.get()->pOptimizedNetworkImpl->GetModelOptions(),
- m_NetworkProperties);
-
//First create tensor handlers, backends and workload factories.
//Handlers are created before workloads are.
//Because workload creation can modify some of the handlers,
@@ -1525,7 +1439,7 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
- if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
+ if (!CheckFlag(tensorHandle->GetImportFlags(), m_NetworkProperties.m_InputSource))
{
throw MemoryImportException(
fmt::format("ImportInputs: Memory Import failed, backend: "
@@ -1537,7 +1451,7 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
inputTensor.second.GetMemoryArea());
- if (tensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource))
+ if (tensorHandle->Import(passThroughTensorHandle->Map(), m_NetworkProperties.m_InputSource))
{
importedInputs.push_back(m_CurImportedInputId++);
passThroughTensorHandle->Unmap();
@@ -1650,14 +1564,14 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors&
ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
- if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
+ if (!CheckFlag(tensorHandle->GetImportFlags(), m_NetworkProperties.m_OutputSource))
{
throw MemoryImportException(fmt::format("ImportInputs: Memory Import failed, backend: "
"{} does not support importing from source {}"
- , factoryId, forceImportMemorySource));
+ , factoryId, m_NetworkProperties.m_OutputSource));
}
- if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
+ if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), m_NetworkProperties.m_OutputSource))
{
importedOutputs.push_back(m_CurImportedOutputId++);
}
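
The CheckFlag()/Import() pattern in the ImportInputs()/ImportOutputs() hunks is a bitmask test: each MemorySource value is treated as a single capability bit, GetImportFlags() reports which bits the backend's tensor handle supports, and the check rejects the import when the source configured in INetworkProperties (m_InputSource or m_OutputSource) is not among them. A minimal sketch of that test, with stand-in types and assumed bit values rather than the real armnn MemorySources header:

#include <cstdint>

// Assumed single-bit encoding; the real armnn enum may differ.
enum class MemorySource : uint32_t
{
    Undefined       = 0,
    Malloc          = 1 << 0,
    DmaBuf          = 1 << 1,
    DmaBufProtected = 1 << 2
};

using MemorySourceFlags = uint32_t;

// True when the backend's advertised import capabilities include 'source'.
bool CheckFlag(MemorySourceFlags backendImportFlags, MemorySource source)
{
    return (backendImportFlags & static_cast<MemorySourceFlags>(source)) != 0;
}

// Usage mirroring the patched import paths above:
//   if (!CheckFlag(tensorHandle->GetImportFlags(), networkProperties.m_InputSource))
//   {
//       throw MemoryImportException(...);   // backend cannot import from this source
//   }
//   tensorHandle->Import(memoryArea, networkProperties.m_InputSource);
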