diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp index bc339f176f..e7ad62f5ff 100644 --- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp @@ -151,7 +151,8 @@ Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITens return Status{}; } -void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info) +void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info, + const Size2D &dilation) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); @@ -160,7 +161,8 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * biases != nullptr ? biases->info() : nullptr, output->info(), conv_info, - weights_info)); + weights_info, + dilation)); _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); @@ -187,7 +189,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * const unsigned int kernel_width = weights->info()->dimension(0); const unsigned int kernel_height = weights->info()->dimension(1); std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_width, kernel_height, - conv_info); + conv_info, dilation); unsigned int mat_weights_cols = weights->info()->dimension(3); unsigned int mat_weights_rows = weights->info()->dimension(0) * weights->info()->dimension(1) * weights->info()->dimension(2) + bias_element; @@ -224,7 +226,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * _memory_group.manage(&_gemm_output); // Configure im2col - _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias); + _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation); // Configure GEMM configure_mm(&_im2col_output, weights, &_gemm_output); @@ -260,7 +262,7 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor * } Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info, - const WeightsInfo &weights_info) + const WeightsInfo &weights_info, const Size2D &dilation) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!"); @@ -282,7 +284,7 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI const unsigned int kernel_width = weights->dimension(0); const unsigned int kernel_height = weights->dimension(1); - std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height, conv_info); + std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height, conv_info, dilation); unsigned int mat_weights_cols = weights->dimension(3); unsigned int mat_weights_rows = weights->dimension(0) * weights->dimension(1) * weights->dimension(2) + bias_element; @@ -298,7 +300,7 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI shape_im2col.set(2, 1); TensorInfo im2col_reshaped_info(shape_im2col, 1, dt, input->fixed_point_position()); im2col_reshaped_info.set_quantization_info(input->quantization_info()); - CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias); + CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation); // Create GEMM output tensor TensorShape shape_gemm = im2col_reshaped_info.tensor_shape(); |