aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/CLScheduler.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/CLScheduler.cpp')
-rw-r--r--src/runtime/CL/CLScheduler.cpp28
1 files changed, 17 insertions, 11 deletions
diff --git a/src/runtime/CL/CLScheduler.cpp b/src/runtime/CL/CLScheduler.cpp
index 6fc7baed63..ef5cb03b32 100644
--- a/src/runtime/CL/CLScheduler.cpp
+++ b/src/runtime/CL/CLScheduler.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -49,6 +49,11 @@ GPUTarget CLScheduler::target() const
return _target;
}
+CLGEMMHeuristicsHandle *CLScheduler::gemm_heuristics() const
+{
+ return _gemm_heuristics;
+}
+
void CLScheduler::set_queue(cl::CommandQueue queue)
{
_queue = std::move(queue);
@@ -92,7 +97,7 @@ bool CLScheduler::is_initialised() const
std::once_flag CLScheduler::_initialize_symbols;
CLScheduler::CLScheduler()
- : _context(), _queue(), _target(GPUTarget::MIDGARD), _is_initialised(false), _cl_tuner(nullptr), _cl_default_static_tuner(nullptr)
+ : _context(), _queue(), _target(GPUTarget::MIDGARD), _is_initialised(false), _cl_tuner(nullptr), _cl_default_static_tuner(nullptr), _gemm_heuristics(nullptr)
{
}
@@ -103,20 +108,20 @@ CLScheduler &CLScheduler::get()
return scheduler;
}
-void CLScheduler::default_init_with_context(cl::Device &device, cl::Context &ctx, ICLTuner *cl_tuner)
+void CLScheduler::default_init_with_context(cl::Device &device, cl::Context &ctx, ICLTuner *cl_tuner, CLGEMMHeuristicsHandle *gemm_h)
{
if(!_is_initialised)
{
const std::string cl_kernels_folder("./cl_kernels/");
cl::CommandQueue queue = cl::CommandQueue(ctx, device);
CLKernelLibrary::get().init(cl_kernels_folder, ctx, device);
- init(ctx, queue, device, cl_tuner);
+ init(ctx, queue, device, cl_tuner, gemm_h);
_cl_default_static_tuner = tuners::TunerFactory::create_tuner(_target);
_cl_tuner = (cl_tuner == nullptr) ? _cl_default_static_tuner.get() : cl_tuner;
}
}
-void CLScheduler::default_init(ICLTuner *cl_tuner)
+void CLScheduler::default_init(ICLTuner *cl_tuner, CLGEMMHeuristicsHandle *gemm_h)
{
if(!_is_initialised)
{
@@ -127,7 +132,7 @@ void CLScheduler::default_init(ICLTuner *cl_tuner)
ARM_COMPUTE_ERROR_ON_MSG(err != CL_SUCCESS, "Failed to create OpenCL context");
cl::CommandQueue queue = cl::CommandQueue(ctx, dev);
CLKernelLibrary::get().init("./cl_kernels/", ctx, dev);
- init(ctx, queue, dev, cl_tuner);
+ init(ctx, queue, dev, cl_tuner, gemm_h);
// Create a default static tuner and set if none was provided
_cl_default_static_tuner = tuners::TunerFactory::create_tuner(_target);
}
@@ -142,13 +147,14 @@ void CLScheduler::set_context(cl::Context context)
CLKernelLibrary::get().set_context(_context);
}
-void CLScheduler::init(cl::Context context, cl::CommandQueue queue, const cl::Device &device, ICLTuner *cl_tuner)
+void CLScheduler::init(cl::Context context, cl::CommandQueue queue, const cl::Device &device, ICLTuner *cl_tuner, CLGEMMHeuristicsHandle *gemm_h)
{
set_context(std::move(context));
- _queue = std::move(queue);
- _target = get_target_from_device(device);
- _is_initialised = true;
- _cl_tuner = cl_tuner;
+ _queue = std::move(queue);
+ _target = get_target_from_device(device);
+ _is_initialised = true;
+ _cl_tuner = cl_tuner;
+ _gemm_heuristics = gemm_h;
}
void CLScheduler::enqueue_common(ICLKernel &kernel, ITensorPack &tensors, bool flush)