diff options
Diffstat (limited to 'src/runtime/CL/functions/CLLocallyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLLocallyConnectedLayer.cpp | 35 |
1 files changed, 24 insertions, 11 deletions
diff --git a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp index 74cb47347f..04e59ac4a6 100644 --- a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp @@ -27,6 +27,11 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "arm_compute/runtime/CL/CLScheduler.h" +#include "src/core/CL/kernels/CLCol2ImKernel.h" +#include "src/core/CL/kernels/CLIm2ColKernel.h" +#include "src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.h" +#include "src/core/CL/kernels/CLWeightsReshapeKernel.h" +#include "support/MemorySupport.h" #include <cmath> #include <tuple> @@ -78,8 +83,16 @@ void calculate_shapes(const ITensorInfo *input, const ITensorInfo *weights, cons } // namespace CLLocallyConnectedLayer::CLLocallyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager) - : _memory_group(std::move(memory_manager)), _input_im2col_kernel(), _weights_reshape_kernel(), _mm_kernel(), _output_col2im_kernel(), _input_im2col_reshaped(), _weights_reshaped(), _gemm_output(), - _is_prepared(false), _original_weights(nullptr) + : _memory_group(std::move(memory_manager)), + _input_im2col_kernel(support::cpp14::make_unique<CLIm2ColKernel>()), + _weights_reshape_kernel(support::cpp14::make_unique<CLWeightsReshapeKernel>()), + _mm_kernel(support::cpp14::make_unique<CLLocallyConnectedMatrixMultiplyKernel>()), + _output_col2im_kernel(support::cpp14::make_unique<CLCol2ImKernel>()), + _input_im2col_reshaped(), + _weights_reshaped(), + _gemm_output(), + _is_prepared(false), + _original_weights(nullptr) { } @@ -169,16 +182,16 @@ void CLLocallyConnectedLayer::configure(const CLCompileContext &compile_context, _memory_group.manage(&_gemm_output); // Configure kernels - _input_im2col_kernel.configure(compile_context, input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _has_bias); - _weights_reshape_kernel.configure(compile_context, weights, biases, &_weights_reshaped); - _mm_kernel.configure(compile_context, &_input_im2col_reshaped, &_weights_reshaped, &_gemm_output); - _output_col2im_kernel.configure(compile_context, &_gemm_output, output, Size2D(conv_w, conv_h)); + _input_im2col_kernel->configure(compile_context, input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _has_bias); + _weights_reshape_kernel->configure(compile_context, weights, biases, &_weights_reshaped); + _mm_kernel->configure(compile_context, &_input_im2col_reshaped, &_weights_reshaped, &_gemm_output); + _output_col2im_kernel->configure(compile_context, &_gemm_output, output, Size2D(conv_w, conv_h)); // Allocate intermediate tensors _input_im2col_reshaped.allocator()->allocate(); _gemm_output.allocator()->allocate(); - CLScheduler::get().tune_kernel_static(_input_im2col_kernel); + CLScheduler::get().tune_kernel_static(*_input_im2col_kernel); } void CLLocallyConnectedLayer::run() @@ -188,13 +201,13 @@ void CLLocallyConnectedLayer::run() MemoryGroupResourceScope scope_mg(_memory_group); // Run input reshaping - CLScheduler::get().enqueue(_input_im2col_kernel); + CLScheduler::get().enqueue(*_input_im2col_kernel); // Runs vector matrix multiply on reshaped matrices - CLScheduler::get().enqueue(_mm_kernel); + CLScheduler::get().enqueue(*_mm_kernel); // Reshape output matrix - CLScheduler::get().enqueue(_output_col2im_kernel, false); + CLScheduler::get().enqueue(*_output_col2im_kernel.get(), false); } void CLLocallyConnectedLayer::prepare() @@ -205,7 +218,7 @@ void CLLocallyConnectedLayer::prepare() // Run weights reshaping and mark original weights tensor as unused _weights_reshaped.allocator()->allocate(); - CLScheduler::get().enqueue(_weights_reshape_kernel); + CLScheduler::get().enqueue(*_weights_reshape_kernel); _original_weights->mark_as_unused(); CLScheduler::get().queue().finish(); |