diff options
Diffstat (limited to 'delegate/src/armnn_delegate.cpp')
-rw-r--r-- | delegate/src/armnn_delegate.cpp | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/delegate/src/armnn_delegate.cpp b/delegate/src/armnn_delegate.cpp index 6250a5f638..6dba890509 100644 --- a/delegate/src/armnn_delegate.cpp +++ b/delegate/src/armnn_delegate.cpp @@ -134,6 +134,10 @@ Delegate::Delegate(armnnDelegate::DelegateOptions options) { runtimeOptions.m_BackendOptions = backendOptions; } + else if (!m_Options.GetOptimizerOptions().m_ModelOptions.empty()) + { + runtimeOptions.m_BackendOptions = m_Options.GetOptimizerOptions().m_ModelOptions; + } m_Runtime = armnn::IRuntime::Create(runtimeOptions); std::vector<armnn::BackendId> backends; @@ -288,7 +292,6 @@ ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext, delegateData.m_OutputSlotForNode = std::vector<armnn::IOutputSlot*>(tfLiteContext->tensors_size, nullptr); - std::vector<armnn::BindingPointInfo> inputBindings; std::vector<armnn::BindingPointInfo> outputBindings; @@ -331,7 +334,8 @@ ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext, { optNet = armnn::Optimize(*(delegateData.m_Network.get()), delegate->m_Options.GetBackends(), - delegate->m_Runtime->GetDeviceSpec()); + delegate->m_Runtime->GetDeviceSpec(), + delegate->m_Options.GetOptimizerOptions()); } catch (std::exception &ex) { @@ -348,11 +352,15 @@ ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext, try { // Load graph into runtime - auto loadingStatus = delegate->m_Runtime->LoadNetwork(networkId, std::move(optNet)); + std::string errorMessage; + auto loadingStatus = delegate->m_Runtime->LoadNetwork(networkId, + std::move(optNet), + errorMessage, + delegate->m_Options.GetNetworkProperties()); if (loadingStatus != armnn::Status::Success) { // Optimize failed - throw armnn::Exception("TfLiteArmnnDelegate: Network could not be loaded!");; + throw armnn::Exception("TfLiteArmnnDelegate: Network could not be loaded:" + errorMessage); } } catch (std::exception& ex) @@ -362,6 +370,12 @@ ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext, throw armnn::Exception(exMessage.str()); } + // Register debug callback function + if (delegate->m_Options.GetDebugCallbackFunction().has_value()) + { + delegate->m_Runtime->RegisterDebugCallback(networkId, delegate->m_Options.GetDebugCallbackFunction().value()); + } + // Create a new SubGraph with networkId and runtime return new ArmnnSubgraph(networkId, delegate->m_Runtime.get(), inputBindings, outputBindings); } |