diff options
Diffstat (limited to 'src/graph/backends/CL/CLDeviceBackend.cpp')
-rw-r--r-- | src/graph/backends/CL/CLDeviceBackend.cpp | 65 |
1 files changed, 38 insertions, 27 deletions
diff --git a/src/graph/backends/CL/CLDeviceBackend.cpp b/src/graph/backends/CL/CLDeviceBackend.cpp index 0159592af6..e27a4109d1 100644 --- a/src/graph/backends/CL/CLDeviceBackend.cpp +++ b/src/graph/backends/CL/CLDeviceBackend.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,19 +23,17 @@ */ #include "arm_compute/graph/backends/CL/CLDeviceBackend.h" -#include "arm_compute/graph/Graph.h" -#include "arm_compute/graph/GraphContext.h" -#include "arm_compute/graph/INode.h" -#include "arm_compute/graph/Logger.h" -#include "arm_compute/graph/Tensor.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/graph/backends/BackendRegistrar.h" #include "arm_compute/graph/backends/CL/CLFunctionFactory.h" #include "arm_compute/graph/backends/CL/CLNodeValidator.h" #include "arm_compute/graph/backends/CL/CLSubTensorHandle.h" #include "arm_compute/graph/backends/CL/CLTensorHandle.h" - -#include "arm_compute/core/CL/CLCoreRuntimeContext.h" -#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/graph/Graph.h" +#include "arm_compute/graph/GraphContext.h" +#include "arm_compute/graph/INode.h" +#include "arm_compute/graph/Logger.h" +#include "arm_compute/graph/Tensor.h" #include "arm_compute/runtime/BlobLifetimeManager.h" #include "arm_compute/runtime/CL/CLBufferAllocator.h" #include "arm_compute/runtime/CL/CLScheduler.h" @@ -65,17 +63,18 @@ bool file_exists(const std::string &filename) static detail::BackendRegistrar<CLDeviceBackend> CLDeviceBackend_registrar(Target::CL); CLDeviceBackend::CLDeviceBackend() - : _context_count(0), _tuner(), _allocator(nullptr), _tuner_file() + : _context_count(0), + _tuner(), + _gemm_heuristics(), + _allocator(nullptr), + _tuner_file(), + _backend_type(CLBackendType::Native) { } CLDeviceBackend::~CLDeviceBackend() { - // TODO (geopin01) : Shouldn't call non exception safe stuff here - if(_tuner.tune_new_kernels() && !_tuner.lws_table().empty() && !_tuner_file.empty()) - { - _tuner.save_to_file(_tuner_file); - } + _tuner.save_to_file(_tuner_file); } void CLDeviceBackend::set_kernel_tuning(bool enable_tuning) @@ -91,16 +90,16 @@ void CLDeviceBackend::set_kernel_tuning_mode(CLTunerMode tuning_mode) void CLDeviceBackend::initialize_backend() { // Setup Scheduler - CLScheduler::get().default_init(&_tuner); + CLScheduler::get().default_init(&_tuner, &_gemm_heuristics, _backend_type); // Create allocator with new context - _allocator = support::cpp14::make_unique<CLBufferAllocator>(nullptr /* legacy path for CLCoreRuntimeContext */); + _allocator = std::make_unique<CLBufferAllocator>(); } void CLDeviceBackend::release_backend_context(GraphContext &ctx) { ARM_COMPUTE_UNUSED(ctx); _context_count--; - if(_context_count == 0) // No more context using the backend: free resources + if (_context_count == 0) // No more context using the backend: free resources { _allocator = nullptr; } @@ -110,15 +109,17 @@ void CLDeviceBackend::setup_backend_context(GraphContext &ctx) { // Force backend initialization _context_count++; - if(_context_count == 1) + if (_context_count == 1) { + _backend_type = ctx.config().backend_type; initialize_backend(); } // Setup tuner _tuner_file = ctx.config().tuner_file; + // Load tuner data if available - if(file_exists(_tuner_file)) + if (file_exists(_tuner_file)) { _tuner.load_from_file(_tuner_file); } @@ -126,8 +127,12 @@ void CLDeviceBackend::setup_backend_context(GraphContext &ctx) set_kernel_tuning(ctx.config().use_tuner); set_kernel_tuning_mode(ctx.config().tuner_mode); + // Attempt to load mlgo heuristics + ARM_COMPUTE_ERROR_ON(CLScheduler::get().gemm_heuristics() == nullptr); + CLScheduler::get().gemm_heuristics()->reload_from_file(ctx.config().mlgo_file); + // Setup a management backend - if(ctx.memory_management_ctx(Target::CL) == nullptr) + if (ctx.memory_management_ctx(Target::CL) == nullptr) { MemoryManagerContext mm_ctx; mm_ctx.target = Target::CL; @@ -140,7 +145,7 @@ void CLDeviceBackend::setup_backend_context(GraphContext &ctx) } // Create function level weights manager - if(ctx.weights_management_ctx(Target::CL) == nullptr) + if (ctx.weights_management_ctx(Target::CL) == nullptr) { WeightsManagerContext wm_ctx; wm_ctx.target = Target::CL; @@ -170,17 +175,18 @@ std::unique_ptr<ITensorHandle> CLDeviceBackend::create_tensor(const Tensor &tens TensorInfo info(tensor_desc.shape, 1, tensor_desc.data_type, tensor_desc.quant_info); info.set_data_layout(tensor_desc.layout); - return support::cpp14::make_unique<CLTensorHandle>(info); + return std::make_unique<CLTensorHandle>(info); } -std::unique_ptr<ITensorHandle> CLDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent) +std::unique_ptr<ITensorHandle> +CLDeviceBackend::create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent) { - if(parent == nullptr) + if (parent == nullptr) { return nullptr; } - return support::cpp14::make_unique<CLSubTensorHandle>(parent, shape, coords, extend_parent); + return std::make_unique<CLSubTensorHandle>(parent, shape, coords, extend_parent); } std::unique_ptr<arm_compute::IFunction> CLDeviceBackend::configure_node(INode &node, GraphContext &ctx) @@ -202,7 +208,7 @@ arm_compute::Status CLDeviceBackend::validate_node(INode &node) std::shared_ptr<arm_compute::IMemoryManager> CLDeviceBackend::create_memory_manager(MemoryManagerAffinity affinity) { - if(affinity == MemoryManagerAffinity::Offset) + if (affinity == MemoryManagerAffinity::Offset) { ARM_COMPUTE_LOG_GRAPH_WARNING("CL Backend does not support offset affinity memory management!"); return nullptr; @@ -220,6 +226,11 @@ std::shared_ptr<arm_compute::IWeightsManager> CLDeviceBackend::create_weights_ma auto weights_mgr = std::make_shared<IWeightsManager>(); return weights_mgr; } + +void CLDeviceBackend::sync() +{ + CLScheduler::get().sync(); +} } // namespace backends } // namespace graph } // namespace arm_compute |