From 17812ba9f7cf2c8f5121c11760ac45fbbdb7aeaf Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Mon, 4 Jun 2018 19:27:13 +0100
Subject: COMPMID-817: Tuner: Port kernels to new design.

Moves the static LWS tuning heuristics out of the kernels' configure()
methods and into the GPU-specific tuners (BifrostTuner and the new
MidgardTuner); the runtime functions now request static tuning through
CLScheduler::tune_kernel_static() after configuring each kernel.

Change-Id: Iaabb1153c2abe0400ec79d51a21347debe92d642
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134062
Tested-by: Jenkins
Reviewed-by: Anthony Barbier
---
 arm_compute/core/CL/kernels/CLCol2ImKernel.h       |   2 +-
 .../core/CL/kernels/CLGEMMMatrixMultiplyKernel.h   |   2 +-
 arm_compute/core/CL/kernels/CLIm2ColKernel.h       |   3 +-
 arm_compute/core/CL/kernels/CLPoolingLayerKernel.h |   4 +-
 arm_compute/runtime/CL/tuners/MidgardTuner.h       |  43 ++++++
 src/core/CL/kernels/CLCol2ImKernel.cpp             |  15 ---
 src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp    |   9 --
 src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp |  48 +------
 .../kernels/CLGEMMMatrixVectorMultiplyKernel.cpp   |   8 --
 src/core/CL/kernels/CLIm2ColKernel.cpp             |   6 +-
 src/core/CL/kernels/CLPoolingLayerKernel.cpp       |  12 +-
 .../CL/functions/CLDepthwiseConvolutionLayer.cpp   |   2 +
 src/runtime/CL/functions/CLFlattenLayer.cpp        |   4 +-
 src/runtime/CL/functions/CLFullyConnectedLayer.cpp |   1 +
 src/runtime/CL/functions/CLGEMM.cpp                |   2 +
 .../CL/functions/CLGEMMConvolutionLayer.cpp        |   8 +-
 .../CL/functions/CLLocallyConnectedLayer.cpp       |   2 +
 src/runtime/CL/functions/CLPoolingLayer.cpp        |   3 +
 src/runtime/CL/tuners/BifrostTuner.cpp             | 150 ++++++++++++++++++++-
 src/runtime/CL/tuners/MidgardTuner.cpp             |  77 +++++++++++
 20 files changed, 300 insertions(+), 101 deletions(-)
 create mode 100644 arm_compute/runtime/CL/tuners/MidgardTuner.h
 create mode 100644 src/runtime/CL/tuners/MidgardTuner.cpp

diff --git a/arm_compute/core/CL/kernels/CLCol2ImKernel.h b/arm_compute/core/CL/kernels/CLCol2ImKernel.h
index 24d0fdd914..3779325efe 100644
--- a/arm_compute/core/CL/kernels/CLCol2ImKernel.h
+++ b/arm_compute/core/CL/kernels/CLCol2ImKernel.h
@@ -86,7 +86,7 @@ public:
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
 
-private:
+public:
     const ICLTensor *_input;
     ICLTensor       *_output;
     std::pair<unsigned int, unsigned int> _convolved_dims;
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
index ee7e7c0e97..13802b97ad 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
@@ -80,7 +80,7 @@ public:
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
 
-private:
+public:
     const ICLTensor *_input0;
     const ICLTensor *_input1;
     ICLTensor       *_output;
diff --git a/arm_compute/core/CL/kernels/CLIm2ColKernel.h b/arm_compute/core/CL/kernels/CLIm2ColKernel.h
index 43812e42a3..45111fcedd 100644
--- a/arm_compute/core/CL/kernels/CLIm2ColKernel.h
+++ b/arm_compute/core/CL/kernels/CLIm2ColKernel.h
@@ -113,9 +113,10 @@ private:
     /** Common signature for the kernel to run */
     using Im2ColFunction = void (CLIm2ColKernel::*)(const Window &, cl::CommandQueue &);
 
-private:
+public:
     const ICLTensor *_input;
     ICLTensor       *_output;
+    PadStrideInfo    _conv_info;
    std::pair<unsigned int, unsigned int> _convolved_dims;
     unsigned int     _num_elems_processed_per_iteration;
     Im2ColFunction   _run_func;
diff --git a/arm_compute/core/CL/kernels/CLPoolingLayerKernel.h b/arm_compute/core/CL/kernels/CLPoolingLayerKernel.h
index e9ce28b3f9..c13507785b 100644
--- a/arm_compute/core/CL/kernels/CLPoolingLayerKernel.h
+++ b/arm_compute/core/CL/kernels/CLPoolingLayerKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -72,7 +72,7 @@ public:
     void run(const Window &window, cl::CommandQueue &queue) override;
     BorderSize border_size() const override;
 
-private:
+public:
     const ICLTensor *_input;
     ICLTensor       *_output;
     PoolingLayerInfo _pool_info;
diff --git a/arm_compute/runtime/CL/tuners/MidgardTuner.h b/arm_compute/runtime/CL/tuners/MidgardTuner.h
new file mode 100644
index 0000000000..4aa58f41f7
--- /dev/null
+++ b/arm_compute/runtime/CL/tuners/MidgardTuner.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TUNERS_MIDGARD_TUNER_H__
+#define __ARM_COMPUTE_TUNERS_MIDGARD_TUNER_H__
+
+#include "arm_compute/runtime/CL/ICLTuner.h"
+
+namespace arm_compute
+{
+namespace tuners
+{
+/** Midgard based OpenCL tuner implementation */
+class MidgardTuner final : public ICLTuner
+{
+public:
+    // Inherited overridden methods
+    void tune_kernel_static(ICLKernel &kernel) override;
+    void tune_kernel_dynamic(ICLKernel &kernel) override;
+};
+} // namespace tuners
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_TUNERS_MIDGARD_TUNER_H__ */
diff --git a/src/core/CL/kernels/CLCol2ImKernel.cpp b/src/core/CL/kernels/CLCol2ImKernel.cpp
index e15da7258a..4e444206f1 100644
--- a/src/core/CL/kernels/CLCol2ImKernel.cpp
+++ b/src/core/CL/kernels/CLCol2ImKernel.cpp
@@ -110,21 +110,6 @@ void CLCol2ImKernel::configure(const ICLTensor *input, ICLTensor *output, std::p
 
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("col2im", build_opts.options()));
 
-    // Configure the local work size for Bifrost with a value obtained
-    // via exhaustive autotuning over 30 representative tensor shapes.
-    const GPUTarget gpu_target = get_target();
-    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-    {
-        if((_convolved_dims.first == 7) || (_convolved_dims.first == 14))
-        {
-            _lws_hint = cl::NDRange(1, 7, 1);
-        }
-        else
-        {
-            _lws_hint = cl::NDRange(1, 8, 1);
-        }
-    }
-
     // Configure kernel window
     auto win_config = validate_and_configure_window(input->info(), output->info(), _convolved_dims);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
diff --git a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
index 41ff2202ca..c89b16eedc 100644
--- a/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseIm2ColKernel.cpp
@@ -90,15 +90,6 @@ void CLDepthwiseIm2ColKernel::configure(const ICLTensor *input, ICLTensor *outpu
 
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("depthwise_im2col", build_opts.options()));
 
-    // Configure the local work size for Bifrost with a value obtained
-    // via exhaustive autotuning for the MobileNets tensor shapes.
-    const GPUTarget gpu_target = get_target();
-
-    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-    {
-        _lws_hint = cl::NDRange(1, 2, 1);
-    }
-
     // Configure kernel window
     Window win = calculate_max_window(*output->info(), Steps());
     // CLDepthwiseIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 7a9760b778..fc52f4e124 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -194,51 +194,9 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
     _output         = output;
     _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
 
-    const DataType data_type = input0->info()->data_type();
-    const int      fp_pos    = input0->info()->fixed_point_position();
-
-    // Get target architecture
-    GPUTarget gpu_target = get_target();
-
-    // Configure LWS hint
-    switch(gpu_target)
-    {
-        case GPUTarget::MIDGARD:
-        case GPUTarget::T600:
-        case GPUTarget::T700:
-        case GPUTarget::T800:
-            if(output->info()->dimension(1) == 196)
-            {
-                _lws_hint = cl::NDRange(1, 7);
-            }
-            else
-            {
-                _lws_hint = cl::NDRange(8, 8);
-            }
-            break;
-        case GPUTarget::G71:
-        case GPUTarget::G72:
-        case GPUTarget::G51:
-        case GPUTarget::G51BIG:
-        case GPUTarget::G51LIT:
-        case GPUTarget::TNOX:
-            if(input1->info()->dimension(1) == 24)
-            {
-                // LWS optimized for the 11x11 AlexNet convolution on Bifrost.
-                _lws_hint = cl::NDRange(2, 2);
-            }
-            else if(output->info()->dimension(1) == 196)
-            {
-                _lws_hint = cl::NDRange(1, 7);
-            }
-            else
-            {
-                _lws_hint = cl::NDRange(8, 8);
-            }
-            break;
-        default:
-            _lws_hint = cl::NullRange;
-    }
+    const DataType  data_type  = input0->info()->data_type();
+    const int       fp_pos     = input0->info()->fixed_point_position();
+    const GPUTarget gpu_target = get_target();
 
     ElementsProcessed num_elements_processed{};
 
diff --git a/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
index 1d6f388def..d8ecd501b0 100644
--- a/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.cpp
@@ -110,14 +110,6 @@ void CLGEMMMatrixVectorMultiplyKernel::configure(const ICLTensor *input0, const
         _kernel.setArg(idx++, -_input1->info()->quantization_info().offset);
     }
 
-    // Configure the local work size for Bifrost with a value obtained
-    // via exhaustive autotuning for the MobileNets tensor shapes.
-    const GPUTarget gpu_target = get_target();
-    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-    {
-        _lws_hint = cl::NDRange(1, 1, 1);
-    }
-
     // Configure kernel window
     const unsigned int num_elems_read_per_iteration = 4;
 
diff --git a/src/core/CL/kernels/CLIm2ColKernel.cpp b/src/core/CL/kernels/CLIm2ColKernel.cpp
index 378456cde6..53a4dca9a3 100644
--- a/src/core/CL/kernels/CLIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLIm2ColKernel.cpp
@@ -61,7 +61,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, b
 } // namespace
 
 CLIm2ColKernel::CLIm2ColKernel()
-    : _input(nullptr), _output(nullptr), _convolved_dims(), _num_elems_processed_per_iteration(1), _run_func(nullptr), _kernel_dims()
+    : _input(nullptr), _output(nullptr), _conv_info(), _convolved_dims(), _num_elems_processed_per_iteration(1), _run_func(nullptr), _kernel_dims()
 {
 }
 
@@ -74,6 +74,7 @@ void CLIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const
 
     _input       = input;
     _output      = output;
+    _conv_info   = conv_info;
     _kernel_dims = kernel_dims;
 
     const DataType data_type = input->info()->data_type();
@@ -190,10 +191,9 @@ void CLIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const
         {
             vector_size = kernel_dims.width;
         }
-        // Local work size and vector size optimized for the 11x11 AlexNet convolution on Bifrost.
+        // Vector size optimized for the 11x11 AlexNet convolution on Bifrost.
         if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && kernel_dims.width == 11)
         {
-            _lws_hint = cl::NDRange(1, 1, 1);
             vector_size = 8;
         }
         const size_t width_mod_vector_size = kernel_dims.width % vector_size;
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
index 3091df4665..b242c5550c 100644
--- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
@@ -208,8 +208,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
     _output    = output;
     _pool_info = pool_info;
 
-    const GPUTarget gpu_target = get_target();
-    const DataType  data_type  = input->info()->data_type();
+    const DataType data_type = input->info()->data_type();
 
     // Set build options
     CLBuildOptions build_opts;
@@ -273,20 +272,11 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
     ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
     ICLKernel::configure(std::get<1>(win_config));
 
-    // Configure the local work size (hint) from the first two dimensions of the global work size.
-    // On Bifrost, this works for up to 35x35xC filters, for which the pooling_layer_3_optimized
-    // kernel is launched with gws=(9, 33, C). In any case, the hint will be ignored if it is
-    // invalid (e.g. exceeds the maximum workgroup size that the kernel can be launched with).
     if(data_layout == DataLayout::NCHW)
     {
         CLPoolingConfig pooling_config     = std::get<2>(win_config);
         _num_elems_processed_per_iteration = pooling_config.first;
         _border_size                       = pooling_config.second;
-        if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
-        {
-            cl::NDRange gws = ICLKernel::gws_from_window(std::get<1>(win_config));
-            _lws_hint       = cl::NDRange(gws[0], gws[1], 1);
-        }
     }
     else
     {
diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
index 676a121a76..c2b24e3c20 100644
--- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
@@ -134,6 +134,7 @@ void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *w
     _input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
     _im2col_kernel.set_target(gpu_target);
     _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias, depth_multiplier);
+    CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
     // Weights reshape configuration
     const TensorShape shape_weights_reshape(patch_size, weights_z);
@@ -149,6 +150,7 @@ void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *w
     _v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out));
     _v2mm_kernel.set_target(gpu_target);
     _v2mm_kernel.configure(&_input_reshaped, &_weights_reshaped, &_v2mm_output);
+    CLScheduler::get().tune_kernel_static(_v2mm_kernel);
     _output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape));
     _vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output, conv_w, conv_h);
 
diff --git a/src/runtime/CL/functions/CLFlattenLayer.cpp b/src/runtime/CL/functions/CLFlattenLayer.cpp
index 9f571b2036..f5809a218a 100644
--- a/src/runtime/CL/functions/CLFlattenLayer.cpp
+++ b/src/runtime/CL/functions/CLFlattenLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,6 +25,7 @@
 #include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
 #include "arm_compute/core/Size2D.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
 #include "support/ToolchainSupport.h"
 
 using namespace arm_compute;
@@ -34,4 +35,5 @@ void CLFlattenLayer::configure(const ICLTensor *input, ICLTensor *output)
     auto k = arm_compute::support::cpp14::make_unique<CLIm2ColKernel>();
     k->configure(input, output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
     _kernel = std::move(k);
+    CLScheduler::get().tune_kernel_static(*_kernel);
 }
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 44bf28374f..9248bc559b 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -117,6 +117,7 @@ void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLT
     // Configure im2col kernel
     _memory_group.manage(&_im2col_output);
     _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
+    CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
     // Configure matrix multiply kernel
     configure_mm(&_im2col_output, weights, output);
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 7f37520f10..a0ec66f804 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -143,7 +143,9 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
         _transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
     }
 
+    // Configure and tune matrix multiply kernel
     _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height));
+    CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     if(_is_interleaved_transposed)
     {
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 4f87043373..27bed44098 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -230,10 +230,11 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
     _gemm_output.allocator()->init(info_gemm);
     _memory_group.manage(&_gemm_output);
 
-    // Configure im2col
+    // Configure and tune im2col
     _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation);
+    CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
-    // Configure GEMM
+    // Configure and tune GEMM
     configure_mm(&_im2col_output, weights, &_gemm_output);
 
     _im2col_output.allocator()->allocate();
@@ -250,8 +251,9 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
         _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset);
     }
 
-    // Configure Col2Im
+    // Configure and tune Col2Im
     _col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, std::make_pair(conv_w, conv_h));
+    CLScheduler::get().tune_kernel_static(_col2im_kernel);
     if(_is_quantized)
     {
         _tmp_output.allocator()->allocate();
diff --git a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
index 986fe00973..31d5cd5a7e 100644
--- a/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLLocallyConnectedLayer.cpp
@@ -163,6 +163,8 @@ void CLLocallyConnectedLayer::configure(const ICLTensor *input, const ICLTensor
     _weights_reshaped.allocator()->allocate();
     _input_im2col_reshaped.allocator()->allocate();
     _gemm_output.allocator()->allocate();
+
+    CLScheduler::get().tune_kernel_static(_input_im2col_kernel);
 }
 
 void CLLocallyConnectedLayer::run()
diff --git a/src/runtime/CL/functions/CLPoolingLayer.cpp b/src/runtime/CL/functions/CLPoolingLayer.cpp
index 17875a38ad..cbe1ce3b47 100644
--- a/src/runtime/CL/functions/CLPoolingLayer.cpp
+++ b/src/runtime/CL/functions/CLPoolingLayer.cpp
@@ -63,6 +63,9 @@ void CLPoolingLayer::configure(ICLTensor *input, ICLTensor *output, const Poolin
         ARM_COMPUTE_ERROR("Data layout not supported");
     }
     _border_handler.configure(input, _kernel->border_size(), border_mode, pixel_value);
+
+    // Tune kernels
+    CLScheduler::get().tune_kernel_static(*_kernel);
 }
 
 Status CLPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info)
diff --git a/src/runtime/CL/tuners/BifrostTuner.cpp b/src/runtime/CL/tuners/BifrostTuner.cpp
index c0ebd24afe..edd074ba08 100644
--- a/src/runtime/CL/tuners/BifrostTuner.cpp
+++ b/src/runtime/CL/tuners/BifrostTuner.cpp
@@ -124,15 +124,163 @@ void tune_direct_convolution_kernel(CLDirectConvolutionLayerKernel &k)
         k.set_lws_hint(lws_hint);
     }
 }
+
+void tune_col2im_kernel(CLCol2ImKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size for Bifrost with a value obtained
+    // via exhaustive autotuning over 30 representative tensor shapes.
+    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
+    {
+        if((k._convolved_dims.first == 7) || (k._convolved_dims.first == 14))
+        {
+            lws_hint = cl::NDRange(1, 7, 1);
+        }
+        else
+        {
+            lws_hint = cl::NDRange(1, 8, 1);
+        }
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+void tune_im2col_kernel(CLIm2ColKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Local work size optimized for the 11x11 AlexNet convolution on Bifrost.
+    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && k._kernel_dims.width == 11)
+    {
+        const bool is_square_kernel = (k._kernel_dims.width == k._kernel_dims.height);
+        if(!is_square_kernel && k._kernel_dims.width > 1 && !k._conv_info.has_padding())
+        {
+            lws_hint = cl::NDRange(1, 1, 1);
+        }
+    }
+    k.set_lws_hint(lws_hint);
+}
+
+void tune_depthwise_im2col_kernel(CLDepthwiseIm2ColKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size for Bifrost with a value obtained
+    // via exhaustive autotuning for the MobileNets tensor shapes.
+    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
+    {
+        lws_hint = cl::NDRange(1, 2, 1);
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+void tune_gemv_kernel(CLGEMMMatrixVectorMultiplyKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size for Bifrost with a value obtained
+    // via exhaustive autotuning for the MobileNets tensor shapes.
+    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
+    {
+        lws_hint = cl::NDRange(1, 1, 1);
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+void tune_gemm_kernel(CLGEMMMatrixMultiplyKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure LWS hint
+    switch(gpu_target)
+    {
+        case GPUTarget::G71:
+        case GPUTarget::G72:
+        case GPUTarget::G51:
+        case GPUTarget::G51BIG:
+        case GPUTarget::G51LIT:
+        case GPUTarget::TNOX:
+            if(k._input1->info()->dimension(1) == 24)
+            {
+                // LWS optimized for the 11x11 AlexNet convolution on Bifrost.
+                lws_hint = cl::NDRange(2, 2);
+            }
+            else if(k._output->info()->dimension(1) == 196)
+            {
+                lws_hint = cl::NDRange(1, 7);
+            }
+            else
+            {
+                lws_hint = cl::NDRange(8, 8);
+            }
+            break;
+        default:
+            lws_hint = cl::NullRange;
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+
+void tune_pooling_kernel(CLPoolingLayerKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    // Configure the local work size (hint) from the first two dimensions of the global work size.
+    // On Bifrost, this works for up to 35x35xC filters, for which the pooling_layer_3_optimized
+    // kernel is launched with gws=(9, 33, C). In any case, the hint will be ignored if it is
+    // invalid (e.g. exceeds the maximum workgroup size that the kernel can be launched with).
+    if(k._input->info()->data_layout() == DataLayout::NCHW)
+    {
+        if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
+        {
+            cl::NDRange gws = ICLKernel::gws_from_window(k.window());
+            lws_hint        = cl::NDRange(gws[0], gws[1], 1);
+        }
+    }
+
+    k.set_lws_hint(lws_hint);
+}
 } // namespace
 
 void BifrostTuner::tune_kernel_static(ICLKernel &kernel)
 {
-    // Continue on tuning if dynamic tuning
     if(dynamic_cast<CLDirectConvolutionLayerKernel *>(&kernel) != nullptr)
     {
         tune_direct_convolution_kernel(*utils::cast::polymorphic_downcast<CLDirectConvolutionLayerKernel *>(&kernel));
     }
+    else if(dynamic_cast<CLCol2ImKernel *>(&kernel) != nullptr)
+    {
+        tune_col2im_kernel(*utils::cast::polymorphic_downcast<CLCol2ImKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLIm2ColKernel *>(&kernel) != nullptr)
+    {
+        tune_im2col_kernel(*utils::cast::polymorphic_downcast<CLIm2ColKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLDepthwiseIm2ColKernel *>(&kernel) != nullptr)
+    {
+        tune_depthwise_im2col_kernel(*utils::cast::polymorphic_downcast<CLDepthwiseIm2ColKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLGEMMMatrixVectorMultiplyKernel *>(&kernel) != nullptr)
+    {
+        tune_gemv_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixVectorMultiplyKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLGEMMMatrixMultiplyKernel *>(&kernel) != nullptr)
+    {
+        tune_gemm_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixMultiplyKernel *>(&kernel));
+    }
+    else if(dynamic_cast<CLPoolingLayerKernel *>(&kernel) != nullptr)
+    {
+        tune_pooling_kernel(*utils::cast::polymorphic_downcast<CLPoolingLayerKernel *>(&kernel));
+    }
 }
 
 void BifrostTuner::tune_kernel_dynamic(ICLKernel &kernel)
diff --git a/src/runtime/CL/tuners/MidgardTuner.cpp b/src/runtime/CL/tuners/MidgardTuner.cpp
new file mode 100644
index 0000000000..2c4b1ac94c
--- /dev/null
+++ b/src/runtime/CL/tuners/MidgardTuner.cpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/CL/tuners/MidgardTuner.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernels.h"
+#include "arm_compute/core/utils/misc/Cast.h"
+
+namespace arm_compute
+{
+namespace tuners
+{
+namespace
+{
+void tune_gemm_kernel(CLGEMMMatrixMultiplyKernel &k)
+{
+    cl::NDRange     lws_hint   = k.lws_hint();
+    const GPUTarget gpu_target = k.get_target();
+
+    switch(gpu_target)
+    {
+        case GPUTarget::MIDGARD:
+        case GPUTarget::T600:
+        case GPUTarget::T700:
+        case GPUTarget::T800:
+            if(k._output->info()->dimension(1) == 196)
+            {
+                lws_hint = cl::NDRange(1, 7);
+            }
+            else
+            {
+                lws_hint = cl::NDRange(8, 8);
+            }
+            break;
+        default:
+            lws_hint = cl::NullRange;
+    }
+
+    k.set_lws_hint(lws_hint);
+}
+} // namespace
+
+void MidgardTuner::tune_kernel_static(ICLKernel &kernel)
+{
+    if(dynamic_cast<CLGEMMMatrixMultiplyKernel *>(&kernel) != nullptr)
+    {
+        tune_gemm_kernel(*utils::cast::polymorphic_downcast<CLGEMMMatrixMultiplyKernel *>(&kernel));
+    }
+}
+
+void MidgardTuner::tune_kernel_dynamic(ICLKernel &kernel)
+{
+    ARM_COMPUTE_UNUSED(kernel);
+}
+} // namespace tuners
+} // namespace arm_compute
\ No newline at end of file
-- 
cgit v1.2.1
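
Usage note: with this design the caller instantiates a GPU-family tuner and hands it to the scheduler; every function configured afterwards then receives the static LWS hints above through CLScheduler::get().tune_kernel_static(). A minimal sketch, assuming CLScheduler::default_init() takes an optional ICLTuner pointer as in this version of the library:

    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/tuners/BifrostTuner.h"

    int main()
    {
        // The tuner must outlive the scheduler's use of it.
        arm_compute::tuners::BifrostTuner tuner;

        // default_init() stores the ICLTuner*; CLScheduler::tune_kernel_static()
        // then forwards each configured kernel to it (assumed signature:
        // void default_init(ICLTuner *cl_tuner = nullptr)).
        arm_compute::CLScheduler::get().default_init(&tuner);

        // ... configure and run CL functions as usual; e.g. CLGEMM::configure()
        // above now calls CLScheduler::get().tune_kernel_static(_mm_kernel).
        return 0;
    }

Centralising the per-GPU heuristics in ICLTuner implementations keeps the kernels target-agnostic, which is why this patch also exposes the tuned kernels' state to the tuners.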