diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp | 54 |
1 files changed, 16 insertions, 38 deletions
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp index fb90415e31..49549a0ad0 100644 --- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp @@ -171,7 +171,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * const DataLayout data_layout = input->info()->data_layout(); const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); - const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES); const unsigned int kernel_width = weights->info()->dimension(idx_width); @@ -193,7 +192,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * ICLTensor *gemm_output_to_use = output; ICLTensor *gemm_output_staged_to_use = output; - const unsigned bias_element = (_append_bias && !_skip_im2col) ? 1 : 0; const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr; // Get parameters from conv_info @@ -212,7 +210,6 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * dilation); unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels); - unsigned int mat_weights_rows = weights->info()->dimension(idx_width) * weights->info()->dimension(idx_height) * weights->info()->dimension(idx_channel) + bias_element; // _weights_reshaped will be auto configured in the kernel. // Just append biases and do not transpose 1xW as it will be reshaped in CLGEMM @@ -223,25 +220,13 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * // Create tensor to store im2col reshaped inputs if(!_skip_im2col) { - // Calculate im2col shape - // For OpenCL the batch size is on the third dimension - // TODO (giaiod01): Use auto-init COMPMID-1277 - TensorShape shape_im2col = input->info()->tensor_shape(); - if(shape_im2col.num_dimensions() >= 3) - { - shape_im2col.remove_dimension(2); - } - shape_im2col.set(0, mat_weights_rows); - shape_im2col.set(1, conv_w * conv_h); - - // FIXME: input->clone() doesn't work with subtensors for grouped convolutions. - TensorInfo im2col_reshaped_info(shape_im2col, 1, data_type); - im2col_reshaped_info.set_quantization_info(input->info()->quantization_info()); - _im2col_output.allocator()->init(im2col_reshaped_info); _memory_group.manage(&_im2col_output); - // Configure and tune im2col + // Configure and tune im2col. im2col output shape is auto-initialized _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation); + + // Set quantization info + _im2col_output.info()->set_quantization_info(input->info()->quantization_info()); CLScheduler::get().tune_kernel_static(_im2col_kernel); // Update GEMM input @@ -350,11 +335,10 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI const ITensorInfo *gemm_output_staged_to_use = output; const ITensorInfo *weights_to_use = weights; - const bool is_nhwc = data_layout == DataLayout::NHWC; - const bool is_quantized = is_data_type_quantized_asymmetric(data_type); - const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized; - const bool append_bias = (biases != nullptr) && (!is_quantized); - const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0; + const bool is_nhwc = data_layout == DataLayout::NHWC; + const bool is_quantized = is_data_type_quantized_asymmetric(data_type); + const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1) && !is_quantized; + const bool append_bias = (biases != nullptr) && (!is_quantized); ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel)); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); @@ -391,7 +375,6 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI dilation); unsigned int mat_weights_cols = weights->dimension(idx_kernels); - unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + bias_element; // Output tensor auto inizialitation if not yet initialized ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? nullptr : biases, nullptr)); @@ -400,19 +383,14 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI if(!skip_im2col) { - // Create tensor info for im2col reshaped inputs - // For OpenCL the batch size is on the third dimension - // TODO (giaiod01): Use auto-init COMPMID-1277 - TensorShape shape_im2col = input->tensor_shape(); - if(input->tensor_shape().num_dimensions() >= 3) - { - shape_im2col.remove_dimension(2); - } - shape_im2col.set(0, mat_weights_rows); - shape_im2col.set(1, conv_w * conv_h); - im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type); - im2col_reshaped_info.set_quantization_info(input->quantization_info()); - ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation)); + const Size2D kernel_dims(kernel_width, kernel_height); + + // Output tensor auto initialization if not yet initialized + TensorShape expected_output_shape = compute_im2col_conv_shape(input, kernel_dims, conv_info, append_bias, dilation, true); + + auto_init_if_empty(im2col_reshaped_info, input->clone()->set_tensor_shape(expected_output_shape)); + + ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, kernel_dims, conv_info, append_bias, dilation)); gemm_input_to_use = &im2col_reshaped_info; } else if(append_bias) |