diff options
Diffstat (limited to 'src/core/CL/kernels/CLPoolingLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLPoolingLayerKernel.cpp | 51 |
1 files changed, 23 insertions, 28 deletions
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp index 043a4bde04..bc5ff73b63 100644 --- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp +++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp @@ -63,13 +63,11 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c "Unsupported combination of parameters!"); const bool is_global_pooling = pool_info.is_global_pooling(); - const unsigned int pool_size = is_global_pooling ? input->tensor_shape().x() : pool_info.pool_size().width; + const unsigned int pool_size_x = is_global_pooling ? input->tensor_shape().x() : pool_info.pool_size().width; + const unsigned int pool_size_y = is_global_pooling ? input->tensor_shape().y() : pool_info.pool_size().height; - ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_global_pooling && (input->tensor_shape().x() != input->tensor_shape().y()), - "Global pooling is supported only with rectangular inputs!"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_global_pooling && ((pool_info.pad_stride_info().pad().first >= pool_size) || (pool_info.pad_stride_info().pad().second >= pool_size)), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_global_pooling && ((pool_info.pad_stride_info().pad().first >= pool_size_x) || (pool_info.pad_stride_info().pad().second >= pool_size_y)), "Invalid pool size and pool pad combination!"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(pool_info.pool_size().width != pool_info.pool_size().height, "Invalid Pool size, width not equal to height!"); // Checks performed when output is configured if(output->total_size() != 0) @@ -81,8 +79,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c unsigned int pooled_h = 0; std::tie(pooled_w, pooled_h) = scaled_dimensions(input->dimension(0), input->dimension(1), - pool_size, - pool_size, + pool_size_x, + pool_size_y, pool_info.pad_stride_info()); ARM_COMPUTE_RETURN_ERROR_ON_MSG((output->dimension(0) != pooled_w) || (output->dimension(1) != pooled_h), "Invalid output pooling dimensions!"); @@ -99,21 +97,19 @@ std::tuple<Status, Window, CLPoolingConfig> validate_and_configure_window(ITenso int pool_stride_y = 0; unsigned int pooled_w = 0; unsigned int pooled_h = 0; - int pool_size = pool_info.pool_size().width; + int pool_size_x = pool_info.is_global_pooling() ? input->dimension(0) : pool_info.pool_size().width; + int pool_size_y = pool_info.is_global_pooling() ? input->dimension(1) : pool_info.pool_size().height; const PadStrideInfo pad_stride_info = pool_info.pad_stride_info(); std::tie(pool_pad_x, pool_pad_y) = pad_stride_info.pad(); std::tie(pool_stride_x, pool_stride_y) = pad_stride_info.stride(); ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - // Update pool size in case of global pooling - pool_size = pool_info.is_global_pooling() ? input->dimension(0) : pool_size; - // Check output dimensions std::tie(pooled_w, pooled_h) = scaled_dimensions(input->dimension(0), input->dimension(1), - pool_size, - pool_size, + pool_size_x, + pool_size_y, pad_stride_info); auto_init(input, output, pooled_w, pooled_h); @@ -126,23 +122,23 @@ std::tuple<Status, Window, CLPoolingConfig> validate_and_configure_window(ITenso // Change the number of elements processed per iteration // for pooling 3x3 with stride less equal than 3 - const bool can_optimize = (pool_size == 3) && (pool_stride_x <= 3) && !is_data_type_quantized(data_type); + const bool can_optimize = (pool_size_x == 3) && (pool_size_y == 3) && (pool_stride_x <= 3) && !is_data_type_quantized(data_type); const unsigned int num_elems_processed_per_iteration = can_optimize ? 4 : 1; - const int num_elems_read_per_iteration = (num_elems_processed_per_iteration - 1) * pool_stride_x + pool_size; + const int num_elems_read_per_iteration = (num_elems_processed_per_iteration - 1) * pool_stride_x + pool_size_x; // Number of iterations in X dimension const int num_iterations_x = (pooled_w + num_elems_processed_per_iteration - 1) / num_elems_processed_per_iteration; // Upper limit for the number of right/bottom border elements that are accessed const int upper_bound_w = ((num_iterations_x - 1) * num_elems_processed_per_iteration * pool_stride_x - pool_pad_x + num_elems_read_per_iteration) - input_width; - const int upper_bound_h = ((pooled_h - 1) * pool_stride_y - pool_pad_y + pool_size) - input_height; + const int upper_bound_h = ((pooled_h - 1) * pool_stride_y - pool_pad_y + pool_size_y) - input_height; border_size.right = std::max(upper_bound_w, pool_pad_x); border_size.bottom = std::max(upper_bound_h, pool_pad_y); Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration)); - AccessWindowRectangle input_access(input, -pool_pad_x, -pool_pad_y, num_elems_read_per_iteration, pool_size, + AccessWindowRectangle input_access(input, -pool_pad_x, -pool_pad_y, num_elems_read_per_iteration, pool_size_y, pool_stride_x * num_elems_processed_per_iteration, pool_stride_y); AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); bool window_changed = update_window_and_padding(win, input_access, output_access); @@ -172,7 +168,8 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, unsigned int pooled_w = 0; unsigned int pooled_h = 0; const PoolingType pool_type = pool_info.pool_type(); - int pool_size = pool_info.pool_size().width; + const int pool_size_x = pool_info.is_global_pooling() ? input->info()->dimension(0) : pool_info.pool_size().width; + const int pool_size_y = pool_info.is_global_pooling() ? input->info()->dimension(1) : pool_info.pool_size().height; const PadStrideInfo pad_stride_info = pool_info.pad_stride_info(); const bool exclude_padding = pool_info.exclude_padding(); std::tie(pool_pad_x, pool_pad_y) = pad_stride_info.pad(); @@ -180,14 +177,11 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - // Update pool size in case of global pooling - pool_size = pool_info.is_global_pooling() ? input->info()->dimension(0) : pool_size; - // Check output dimensions std::tie(pooled_w, pooled_h) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), - pool_size, - pool_size, + pool_size_x, + pool_size_y, pad_stride_info); auto_init(input->info(), output->info(), pooled_w, pooled_h); @@ -220,22 +214,23 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output, } // Create kernel - if((pool_size == 3) && !is_data_type_quantized_asymmetric(data_type)) + if((pool_size_x == 3) && (pool_size_y == 3) && !is_data_type_quantized_asymmetric(data_type)) { // Check if we have pool3x3 with stride_x less equal than 3. In these cases, run an optimized OpenCL kernel where // each thread computes 4 output elements - const bool is_pool3x3_stride_le3 = (pool_size == 3) && (pool_stride_x <= 3) && !is_data_type_fixed_point(data_type); + const bool is_pool3x3_stride_le3 = (pool_size_x == 3) && (pool_size_y == 3) && (pool_stride_x <= 3) && !is_data_type_fixed_point(data_type); std::string kernel_name = ((is_pool3x3_stride_le3) ? "pooling_layer_optimized_" : "pooling_layer_") - + support::cpp11::to_string(pool_size); + + support::cpp11::to_string(pool_size_x); _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); } else // Run general case { - build_opts.add_option("-DPOOL_SIZE=" + support::cpp11::to_string(pool_size)); + build_opts.add_option("-DPOOL_SIZE_X=" + support::cpp11::to_string(pool_size_x)); + build_opts.add_option("-DPOOL_SIZE_Y=" + support::cpp11::to_string(pool_size_y)); build_opts.add_option_if(data_type == DataType::F16, "-DFP16"); - std::string kernel_name = is_data_type_quantized_asymmetric(data_type) ? "pooling_layer_N_quantized" : "pooling_layer_N"; + std::string kernel_name = is_data_type_quantized_asymmetric(data_type) ? "pooling_layer_MxN_quantized" : "pooling_layer_MxN"; _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); } |