aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp')
-rw-r--r--src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp45
1 files changed, 35 insertions, 10 deletions
diff --git a/src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp b/src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp
index d6b038f0f8..48d806dc7c 100644
--- a/src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp
+++ b/src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -79,8 +79,30 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
ARM_COMPUTE_UNUSED(output);
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- bool window_changed = false;
- Window win = calculate_max_window(*input, Steps(1, 1));
+ bool window_changed = false;
+ int num_elems_processed_per_iteration = 1;
+
+ if(input->data_layout() == DataLayout::NHWC)
+ {
+ // In the case of FP16 computation, we can perform more
+ // output feature maps in a single work-item.
+ // From experiments, num_elems_processed_per_iteration = 2 looks good for fp16 to
+ // improve the performance. However, in order to make the implementation simpler,
+ // we set num_elems_processed_per_iteration = 2 only when the OFMs are multiple of 2.
+ // Note: At the moment, only Winograd Input Transform 3x3 can support N0 != 1
+ const DataType dt = input->data_type();
+ const size_t dim0 = input->dimension(0);
+ const size_t k_sz = winograd_info.kernel_size.area();
+ const bool cond = dt == DataType::F16 && ((dim0 % 2) == 0);
+ if(cond)
+ {
+ if(k_sz == 3 || k_sz == 9)
+ {
+ num_elems_processed_per_iteration = 2;
+ }
+ }
+ }
+ Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
if(input->data_layout() == DataLayout::NCHW)
{
@@ -143,12 +165,19 @@ void ClWinogradInputTransformKernel::configure(const ClCompileContext &compile_c
ARM_COMPUTE_ERROR_ON(_num_tiles_x * _num_tiles_y != static_cast<int>(dst->dimension(1)));
const size_t total_batches = src->tensor_shape().total_size_upper(3);
+ // Create window and update padding
+ auto win_config = validate_and_configure_window(src, dst, winograd_info);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ IClKernel::configure_internal(win_config.second, cl::NDRange(1, 1, 8));
+
+ _src_width = src->dimension(idx_w);
+ _src_height = src->dimension(idx_h);
+
CLBuildOptions build_opts;
if(_data_layout == DataLayout::NHWC)
{
build_opts.add_option("-DNHWC");
- _src_width = src->dimension(idx_w);
- _src_height = src->dimension(idx_h);
+ build_opts.add_option("-DN0=" + support::cpp11::to_string(win_config.second.x().step()));
build_opts.add_option("-DPAD_LEFT=" + support::cpp11::to_string(conv_info.pad_left()));
build_opts.add_option("-DPAD_TOP=" + support::cpp11::to_string(conv_info.pad_top()));
build_opts.add_option("-DOUTPUT_TILE_W=" + support::cpp11::to_string(output_tile_size.width));
@@ -156,6 +185,7 @@ void ClWinogradInputTransformKernel::configure(const ClCompileContext &compile_c
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src->data_type()));
build_opts.add_option_if(winograd_info.kernel_size.height == 1, "-DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL");
build_opts.add_option_if(winograd_info.kernel_size.width == 1, "-DWINOGRAD_INPUT_TRANSFORM_VERTICAL");
+ build_opts.add_option_if(total_batches > 1, "-DIS_BATCHED");
}
else
{
@@ -191,11 +221,6 @@ void ClWinogradInputTransformKernel::configure(const ClCompileContext &compile_c
build_opts.add_option("-D" + upper_string(kernel_name));
_kernel = create_kernel(compile_context, kernel_name, build_opts.options());
- // Create window and update padding
- auto win_config = validate_and_configure_window(src, dst, winograd_info);
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- IClKernel::configure_internal(win_config.second, cl::NDRange(1, 1, 8));
-
_border_size = BorderSize(src->padding());
ARM_COMPUTE_ERROR_ON((src->data_layout() == DataLayout::NHWC) && has_padding_changed(padding_info));