aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp')
-rw-r--r--src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp24
1 files changed, 20 insertions, 4 deletions
diff --git a/src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp b/src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp
index 4ba6ba8a9a..136376a39f 100644
--- a/src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp
+++ b/src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -108,7 +108,16 @@ void ClWinogradFilterTransformKernel::configure(const ClCompileContext &compile_
// Set build options
CLBuildOptions build_opts;
- build_opts.add_option("-DSRC_DIM_Z=" + support::cpp11::to_string(src->dimension(2)));
+
+ // For NHWC layouts pass tensor dimesions at runtime
+ if(src->data_layout() == DataLayout::NHWC)
+ {
+ _src_dim_z = src->dimension(2);
+ }
+ else
+ {
+ build_opts.add_option("-DSRC_DIM_Z=" + support::cpp11::to_string(src->dimension(2)));
+ }
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src->data_type()));
build_opts.add_option_if(winograd_info.kernel_size.height == 1, "-DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL");
build_opts.add_option_if(winograd_info.kernel_size.width == 1, "-DWINOGRAD_FILTER_TRANSFORM_VERTICAL");
@@ -117,7 +126,10 @@ void ClWinogradFilterTransformKernel::configure(const ClCompileContext &compile_
// Create kernel
std::string kernel_name = "winograd_filter_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_" + lower_string(string_from_data_layout(src->data_layout()));
- _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
+
+ // A macro guard to compile ONLY the kernel of interest
+ build_opts.add_option("-D" + upper_string(kernel_name));
+ _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
// Configure kernel window
auto win_config = validate_and_configure_window(src, dst);
@@ -149,8 +161,12 @@ void ClWinogradFilterTransformKernel::run_op(ITensorPack &tensors, const Window
unsigned int idx = 0;
add_4D_tensor_argument(idx, src, window);
add_3D_tensor_argument(idx, dst, window_out);
+ if(src->info()->data_layout() == DataLayout::NHWC)
+ {
+ _kernel.setArg<cl_uint>(idx++, _src_dim_z);
+ }
enqueue(queue, *this, window, lws_hint());
}
} // namespace kernels
} // namespace opencl
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute