From 215b4ea6c9dee480a22070d5873b0b8cb52531a0 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Thu, 28 Jun 2018 16:29:29 +0100
Subject: COMPMID-1277 - Optimizing CLIm2ColKernel for NHWC.

This patch includes:

- Im2Col optimizations for NHWC using a new data layout
- Refactoring of CLIm2ColKernel adding validation method and auto-init
- Removed im2col_reduced from CLIm2ColKernel and created a new kernel CLFlattenLayerKernel

Change-Id: I1620640b6796baa268324b33ae92cdd8de53e27c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141241
Tested-by: Jenkins
Reviewed-by: Giorgio Arena
---
 src/runtime/CL/functions/CLFlattenLayer.cpp        | 12 +++--
 src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 34 +++++++-------
 src/runtime/CL/functions/CLGEMM.cpp                |  3 +-
 .../CL/functions/CLGEMMConvolutionLayer.cpp        | 54 +++++++---------------
 .../NEON/functions/NEFullyConnectedLayer.cpp       |  6 +--
 .../NEON/functions/NEGEMMConvolutionLayer.cpp      |  4 +-
 6 files changed, 47 insertions(+), 66 deletions(-)

(limited to 'src/runtime')

diff --git a/src/runtime/CL/functions/CLFlattenLayer.cpp b/src/runtime/CL/functions/CLFlattenLayer.cpp
index f5809a218a..b372c35dd9 100644
--- a/src/runtime/CL/functions/CLFlattenLayer.cpp
+++ b/src/runtime/CL/functions/CLFlattenLayer.cpp
@@ -23,8 +23,7 @@
  */
 #include "arm_compute/runtime/CL/functions/CLFlattenLayer.h"
 
-#include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
-#include "arm_compute/core/Size2D.h"
+#include "arm_compute/core/CL/kernels/CLFlattenLayerKernel.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "support/ToolchainSupport.h"
 
@@ -32,8 +31,13 @@ using namespace arm_compute;
 
 void CLFlattenLayer::configure(const ICLTensor *input, ICLTensor *output)
 {
-    auto k = arm_compute::support::cpp14::make_unique<CLIm2ColKernel>();
-    k->configure(input, output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
+    auto k = arm_compute::support::cpp14::make_unique<CLFlattenLayerKernel>();
+    k->configure(input, output);
     _kernel = std::move(k);
     CLScheduler::get().tune_kernel_static(*_kernel);
 }
+
+Status CLFlattenLayer::validate(const ITensorInfo *input, const ITensorInfo *output)
+{
+    return CLFlattenLayerKernel::validate(input, output);
+}
\ No newline at end of file
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 6fd78a3fc9..60c28a0874 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -73,12 +73,11 @@ Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c
 }
 
 CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _im2col_kernel(), _convert_weights(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
-      _accumulate_biases_kernel(), _im2col_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
+    : _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
+      _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
       _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
 {
 }
-
 void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
 {
     if(_is_quantized)
@@ -111,20 +110,19 @@ void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLT
 
     // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
 
-    // Initialize output tensor for im2col
-    TensorShape shape_im2col = compute_im2col_fc_shape(input->info());
-    _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col).set_data_layout(DataLayout::NCHW));
+    // Initialize output tensor for flatten
+    TensorShape shape_flatten = compute_flatten_shape(input->info());
+    _flatten_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_flatten).set_data_layout(DataLayout::NCHW));
 
-    // Configure im2col kernel
-    _memory_group.manage(&_im2col_output);
-    _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
-    CLScheduler::get().tune_kernel_static(_im2col_kernel);
+    // Configure flatten kernel
+    _memory_group.manage(&_flatten_output);
+    _flatten_layer.configure(input, &_flatten_output);
 
     // Configure matrix multiply kernel
-    configure_mm(&_im2col_output, weights, output);
+    configure_mm(&_flatten_output, weights, output);
 
-    // Allocate the output tensor for im2col once all the configure methods have been called
-    _im2col_output.allocator()->allocate();
+    // Allocate the output tensor for flatten once all the configure methods have been called
+    _flatten_output.allocator()->allocate();
 }
 
 void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
@@ -254,7 +252,7 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
     bool            is_quantized = is_data_type_quantized_asymmetric(input->data_type());
     const GPUTarget gpu_target   = CLScheduler::get().target();
 
-    const ITensorInfo &im2col_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)).set_data_layout(DataLayout::NCHW));
+    const ITensorInfo &flatten_input    = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)).set_data_layout(DataLayout::NCHW));
     const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
     const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
     const ITensorInfo &gemmlowp_output  = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
@@ -311,9 +309,9 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
         // Fully Connected layer after a Convolution Layer without batches
         ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (input->dimension(0) * input->dimension(1) * input->dimension(2))));
 
-        // Validate im2col kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_input, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false));
-        input_to_use = &im2col_input;
+        // Validate flatten kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLFlattenLayer::validate(input, &flatten_input));
+        input_to_use = &flatten_input;
     }
     else
     {
@@ -341,7 +339,7 @@ void CLFullyConnectedLayer::run()
     // Linearize input if it comes from a convolutional layer
     if(_is_fc_after_conv)
     {
-        CLScheduler::get().enqueue(_im2col_kernel, false);
+        _flatten_layer.run();
     }
 
     // Run matrix multiply
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1d1b17bbf1..a8d7058f2a 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -171,6 +171,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
 Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(output);
 
     // Check if we need to reshape the matrix B only on the first run
     const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
@@ -180,7 +181,7 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
 
     TensorInfo tmp_a_info{};
     TensorInfo tmp_b_info{};
-    TensorInfo tmp_output_info = *output->clone();
+    TensorInfo tmp_output_info{};
 
     // Get the GPU target
     const GPUTarget gpu_target = CLScheduler::get().target();
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index fb90415e31..49549a0ad0 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -171,7 +171,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
     const DataLayout data_layout = input->info()->data_layout();
     const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
     const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-    const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
     const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
 
     const unsigned int kernel_width  = weights->info()->dimension(idx_width);
@@ -193,7 +192,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
     ICLTensor       *gemm_output_to_use        = output;
     ICLTensor       *gemm_output_staged_to_use = output;
 
-    const unsigned   bias_element  = (_append_bias && !_skip_im2col) ? 1 : 0;
     const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
 
     // Get parameters from conv_info
@@ -212,7 +210,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
                                                  dilation);
 
     unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels);
-    unsigned int mat_weights_rows = weights->info()->dimension(idx_width) * weights->info()->dimension(idx_height) * weights->info()->dimension(idx_channel) + bias_element;
 
     // _weights_reshaped will be auto configured in the kernel.
     // Just append biases and do not transpose 1xW as it will be reshaped in CLGEMM
@@ -223,25 +220,13 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
     // Create tensor to store im2col reshaped inputs
     if(!_skip_im2col)
     {
-        // Calculate im2col shape
-        // For OpenCL the batch size is on the third dimension
-        // TODO (giaiod01): Use auto-init COMPMID-1277
-        TensorShape shape_im2col = input->info()->tensor_shape();
-        if(shape_im2col.num_dimensions() >= 3)
-        {
-            shape_im2col.remove_dimension(2);
-        }
-        shape_im2col.set(0, mat_weights_rows);
-        shape_im2col.set(1, conv_w * conv_h);
-
-        // FIXME: input->clone() doesn't work with subtensors for grouped convolutions.
-        TensorInfo im2col_reshaped_info(shape_im2col, 1, data_type);
-        im2col_reshaped_info.set_quantization_info(input->info()->quantization_info());
-        _im2col_output.allocator()->init(im2col_reshaped_info);
         _memory_group.manage(&_im2col_output);
 
-        // Configure and tune im2col
+        // Configure and tune im2col. im2col output shape is auto-initialized
         _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
+
+        // Set quantization info
+        _im2col_output.info()->set_quantization_info(input->info()->quantization_info());
         CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
         // Update GEMM input
@@ -350,11 +335,10 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
     const ITensorInfo *gemm_output_staged_to_use = output;
     const ITensorInfo *weights_to_use            = weights;
 
-    const bool     is_nhwc      = data_layout == DataLayout::NHWC;
-    const bool     is_quantized = is_data_type_quantized_asymmetric(data_type);
-    const bool     skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized;
-    const bool     append_bias  = (biases != nullptr) && (!is_quantized);
-    const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
+    const bool is_nhwc      = data_layout == DataLayout::NHWC;
+    const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+    const bool skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized;
+    const bool append_bias  = (biases != nullptr) && (!is_quantized);
 
     ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -391,7 +375,6 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
                                                  dilation);
 
     unsigned int mat_weights_cols = weights->dimension(idx_kernels);
-    unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + bias_element;
 
     // Output tensor auto inizialitation if not yet initialized
     ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? nullptr : biases, nullptr));
@@ -400,19 +383,14 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
 
     if(!skip_im2col)
     {
-        // Create tensor info for im2col reshaped inputs
-        // For OpenCL the batch size is on the third dimension
-        // TODO (giaiod01): Use auto-init COMPMID-1277
-        TensorShape shape_im2col = input->tensor_shape();
-        if(input->tensor_shape().num_dimensions() >= 3)
-        {
-            shape_im2col.remove_dimension(2);
-        }
-        shape_im2col.set(0, mat_weights_rows);
-        shape_im2col.set(1, conv_w * conv_h);
-        im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
-        im2col_reshaped_info.set_quantization_info(input->quantization_info());
-        ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
+        const Size2D kernel_dims(kernel_width, kernel_height);
+
+        // Output tensor auto initialization if not yet initialized
+        TensorShape expected_output_shape = compute_im2col_conv_shape(input, kernel_dims, conv_info, append_bias, dilation, true);
+
+        auto_init_if_empty(im2col_reshaped_info, input->clone()->set_tensor_shape(expected_output_shape));
+
+        ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, kernel_dims, conv_info, append_bias, dilation));
         gemm_input_to_use = &im2col_reshaped_info;
     }
     else if(append_bias)
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 25b8adc431..c2f0283d4e 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -113,7 +113,7 @@ void NEFullyConnectedLayer::configure_conv_fc(const ITensor *input, const ITenso
     // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
 
     // Initialize output tensor for im2col
-    TensorShape shape_im2col = compute_im2col_fc_shape(input->info());
+    TensorShape shape_im2col = compute_flatten_shape(input->info());
     _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
 
     // Configure im2col kernel
@@ -249,7 +249,7 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
     bool is_fc_after_conv = true;
     bool is_quantized     = is_data_type_quantized_asymmetric(input->data_type());
 
-    const ITensorInfo &im2col_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_im2col_fc_shape(input)));
+    const ITensorInfo &im2col_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)));
     const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
     const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());
     const ITensorInfo &gemmlowp_output  = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_data_type(DataType::S32));
@@ -420,4 +420,4 @@ void NEFullyConnectedLayer::prepare()
 
         _is_prepared = true;
     }
-}
+}
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index c0a5d0a436..df4a040bad 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -223,7 +223,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
     {
         // Calculate im2col shape
         // For NEON the batch size is on the fourth dimension
-        // TODO (giaiod01): Use auto-init COMPMID-1277
+        // TODO (giaiod01): Auto-initialize the output shape of im2col COMPMID-1482
         TensorShape shape_im2col = input->info()->tensor_shape();
         shape_im2col.set(0, mat_weights_rows);
         shape_im2col.set(1, conv_w * conv_h);
@@ -232,7 +232,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
         _im2col_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
         _memory_group.manage(&_im2col_output);
 
-        // Configure and tune im2col
+        // Configure
        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, false, false, dilation);
 
         // Update GEMM input
--
cgit v1.2.1
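
For reviewers, the sketch below exercises the flatten path this patch introduces: the new static CLFlattenLayer::validate() followed by configure() and run(), now backed by CLFlattenLayerKernel instead of CLIm2ColKernel. Only those entry points come from the patch; the tensor shape, data type, the main() scaffolding, and the use of misc::shape_calculator::compute_flatten_shape() to size the destination are illustrative assumptions.

// Minimal usage sketch (not part of the patch); shapes and data type are illustrative.
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "arm_compute/runtime/CL/functions/CLFlattenLayer.h"

using namespace arm_compute;

int main()
{
    // Create the default OpenCL context and queue for the CL runtime.
    CLScheduler::get().default_init();

    // Example 3D activation tensor (W=7, H=3, C=2); chosen for illustration only.
    CLTensor src;
    src.allocator()->init(TensorInfo(TensorShape(7U, 3U, 2U), 1, DataType::F32));

    // The flattened output collapses W*H*C into a single dimension (42 here).
    const TensorShape flat_shape = misc::shape_calculator::compute_flatten_shape(src.info());
    CLTensor dst;
    dst.allocator()->init(TensorInfo(flat_shape, 1, DataType::F32));

    // New in this patch: static validation before configuring the function.
    const Status status = CLFlattenLayer::validate(src.info(), dst.info());
    if(status.error_code() != ErrorCode::OK)
    {
        return 1;
    }

    // Configure the function; it is now backed by CLFlattenLayerKernel.
    CLFlattenLayer flatten;
    flatten.configure(&src, &dst);

    // Allocate the CL buffers and run (filling src with real data is omitted here).
    src.allocator()->allocate();
    dst.allocator()->allocate();
    flatten.run();
    CLScheduler::get().sync();

    return 0;
}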