aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r--src/armnn/LoadedNetwork.cpp96
1 files changed, 91 insertions, 5 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index ec79d5da3e..a27add921e 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -84,6 +84,87 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils
} // anonymous
+/**
+ * This function performs a sanity check to ensure that the combination of input and output memory source matches the
+ * values for importEnabled and exportEnabled that were specified during optimization. During optimization the tensor
+ * handle factories are chosen based on whether import and export are enabled. If the user then specifies something
+ * incompatible here it can lead to problems.
+ *
+ * @param optimizedOptions
+ * @param networkProperties
+ */
+void ValidateSourcesMatchOptimizedNetwork(std::vector<BackendOptions> optimizedOptions,
+ const INetworkProperties& networkProperties)
+{
+ // Find the "Global" backend options. During the optimize phase the values of importEnabled and exportEnabled are
+ // added as backend options.
+ const vector<BackendOptions>::iterator& backendItr =
+ find_if(optimizedOptions.begin(), optimizedOptions.end(), [](const BackendOptions& backend) {
+ if (backend.GetBackendId().Get() == "Global")
+ {
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ });
+ bool importEnabled = false;
+ bool exportEnabled = false;
+ if (backendItr != optimizedOptions.end())
+ {
+ // Find the importEnabled and exportEnabled values.
+ for (size_t i = 0; i < backendItr->GetOptionCount(); i++)
+ {
+ const BackendOptions::BackendOption& option = backendItr->GetOption(i);
+ if (option.GetName() == "ImportEnabled")
+ {
+ importEnabled = option.GetValue().AsBool();
+ }
+ if (option.GetName() == "ExportEnabled")
+ {
+ exportEnabled = option.GetValue().AsBool();
+ }
+ }
+ }
+
+ // Now that we have values for import and export compare them to the MemorySource variables.
+ // Any value of MemorySource that's not "Undefined" implies that we need to do an import of some kind.
+ if ((networkProperties.m_InputSource == MemorySource::Undefined && importEnabled) ||
+ (networkProperties.m_InputSource != MemorySource::Undefined && !importEnabled))
+ {
+ auto message = fmt::format("The input memory source specified, '{0}',", networkProperties.m_InputSource);
+ if (!importEnabled)
+ {
+ message.append(" requires that memory import be enabled. However, "
+ "it was disabled when this network was optimized.");
+ }
+ else
+ {
+ message.append(" requires that memory import be disabled. However, "
+ "it was enabled when this network was optimized.");
+ }
+ throw InvalidArgumentException(message);
+ }
+
+ if ((networkProperties.m_OutputSource == MemorySource::Undefined && exportEnabled) ||
+ (networkProperties.m_OutputSource != MemorySource::Undefined && !exportEnabled))
+ {
+ auto message = fmt::format("The output memory source specified, '{0}',", networkProperties.m_OutputSource);
+ if (!exportEnabled)
+ {
+ message.append(" requires that memory export be enabled. However, "
+ "it was disabled when this network was optimized.");
+ }
+ else
+ {
+ message.append(" requires that memory export be disabled. However, "
+ "it was enabled when this network was optimized.");
+ }
+ throw InvalidArgumentException(message);
+ }
+} // anonymous
+
std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
@@ -136,6 +217,11 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
profiler->EnableNetworkDetailsToStdOut(networkProperties.m_OutputNetworkDetailsMethod);
+ // We need to check that the memory sources match up with the values of import and export specified during the
+ // optimize phase. If they don't this will throw an exception.
+ ValidateSourcesMatchOptimizedNetwork(m_OptimizedNetwork.get()->pOptimizedNetworkImpl->GetModelOptions(),
+ m_NetworkProperties);
+
//First create tensor handlers, backends and workload factories.
//Handlers are created before workloads are.
//Because workload creation can modify some of the handlers,
@@ -1439,7 +1525,7 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
- if (!CheckFlag(tensorHandle->GetImportFlags(), m_NetworkProperties.m_InputSource))
+ if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
{
throw MemoryImportException(
fmt::format("ImportInputs: Memory Import failed, backend: "
@@ -1451,7 +1537,7 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
inputTensor.second.GetMemoryArea());
- if (tensorHandle->Import(passThroughTensorHandle->Map(), m_NetworkProperties.m_InputSource))
+ if (tensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource))
{
importedInputs.push_back(m_CurImportedInputId++);
passThroughTensorHandle->Unmap();
@@ -1564,14 +1650,14 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors&
ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
- if (!CheckFlag(tensorHandle->GetImportFlags(), m_NetworkProperties.m_OutputSource))
+ if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
{
throw MemoryImportException(fmt::format("ImportInputs: Memory Import failed, backend: "
"{} does not support importing from source {}"
- , factoryId, m_NetworkProperties.m_OutputSource));
+ , factoryId, forceImportMemorySource));
}
- if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), m_NetworkProperties.m_OutputSource))
+ if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
{
importedOutputs.push_back(m_CurImportedOutputId++);
}