diff options
Diffstat (limited to 'src/backends/cl/ClBackendContext.cpp')
-rw-r--r-- | src/backends/cl/ClBackendContext.cpp | 28 |
1 files changed, 26 insertions, 2 deletions
diff --git a/src/backends/cl/ClBackendContext.cpp b/src/backends/cl/ClBackendContext.cpp index 125f01b627..9c5cca9d3a 100644 --- a/src/backends/cl/ClBackendContext.cpp +++ b/src/backends/cl/ClBackendContext.cpp @@ -21,8 +21,9 @@ namespace armnn struct ClBackendContext::ClContextControlWrapper { ClContextControlWrapper(arm_compute::CLTuner* tuner, + arm_compute::CLGEMMHeuristicsHandle* heuristicsHandle, bool profilingEnabled) - : m_ClContextControl(tuner, profilingEnabled) + : m_ClContextControl(tuner, heuristicsHandle, profilingEnabled) {} bool Sync() @@ -143,6 +144,7 @@ ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options) bool kernelProfiling = options.m_EnableGpuProfiling; arm_compute::CLTuner* tuner = nullptr; + arm_compute::CLGEMMHeuristicsHandle* mlgoTuner = nullptr; bool useLegacyTunerAPI = options.m_GpuAccTunedParameters.get() != nullptr; if (useLegacyTunerAPI) { @@ -197,6 +199,10 @@ ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options) { tuningLevel = ParseTuningLevel(value, defaultTuningLevel); } + else if (name == "MLGOTuningFilePath") + { + m_MLGOTuningFile = ParseFile(value, ""); + } }); // Create the tuner, in tuning mode initially. @@ -216,13 +222,31 @@ ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options) ARMNN_LOG(warning) << "Could not load GpuAcc tuner data file."; } } + + if (!m_MLGOTuningFile.empty()) + { + try + { + ARMNN_LOG(info) << "Loading Gpu MLGO tuning data from file: " << m_TuningFile; + if(m_MLGOTuner.reload_from_file(m_MLGOTuningFile.c_str())) + { + mlgoTuner = &m_MLGOTuner; + } + } + catch (const std::exception& e) + { + ARMNN_LOG(warning) << "Could not load GpuAcc MLGO tuner data file."; + } + } + tuner = m_Tuner.get(); } m_ClContextControlWrapper = std::make_unique<ClContextControlWrapper>( tuner, + mlgoTuner, kernelProfiling - ); + ); } bool ClBackendContext::BeforeLoadNetwork(NetworkId) |