diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp index af794354c3..05988997e7 100644 --- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp @@ -275,6 +275,9 @@ void ClGemmMatrixMultiplyNativeKernel::configure(const CLCompileContext &compile // Shrink M0 to be always <= M (internal_m) to prevent out-of-bounds reads. // NOTE: This might have implications on heuristics and performance const unsigned int internal_m0 = std::min(internal_m, lhs_info.m0); + _m = internal_m; + _n = gemm_info.n; + _k = gemm_info.k; // Create build options CLBuildOptions build_opts; @@ -289,9 +292,6 @@ void ClGemmMatrixMultiplyNativeKernel::configure(const CLCompileContext &compile build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(d_gemm_3d)); build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(src1->dimension(2))); build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS"); - build_opts.add_option("-DM=" + support::cpp11::to_string(internal_m)); - build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n)); - build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k)); build_opts.add_option("-DM0=" + support::cpp11::to_string(internal_m0)); build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0)); @@ -312,6 +312,9 @@ void ClGemmMatrixMultiplyNativeKernel::configure(const CLCompileContext &compile std::string kernel_name("gemm_mm_native"); post_op_utils.set_post_ops_cl_kernel_name(kernel_name, gemm_info.post_ops); + // A macro guard to compile ONLY the kernel of interest + build_opts.add_option("-D" + upper_string(kernel_name)); + // Create kernel _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); @@ -392,11 +395,11 @@ void ClGemmMatrixMultiplyNativeKernel::run_op(ITensorPack &tensors, const Window unsigned int idx0; if(_add_bias) { - idx0 = (4 + _num_post_op_args) * num_arguments_per_2D_tensor() + (4 + _num_post_op_args); + idx0 = (4 + _num_post_op_args) * num_arguments_per_2D_tensor() + (7 + _num_post_op_args); } else { - idx0 = (3 + _num_post_op_args) * num_arguments_per_2D_tensor() + (3 + _num_post_op_args); + idx0 = (3 + _num_post_op_args) * num_arguments_per_2D_tensor() + (6 + _num_post_op_args); } const unsigned int total_cross_plane_pad = src0->info()->padding().top + src0->info()->padding().bottom; _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad)); @@ -408,11 +411,11 @@ void ClGemmMatrixMultiplyNativeKernel::run_op(ITensorPack &tensors, const Window unsigned int idx0; if(_add_bias) { - idx0 = (4 + _num_post_op_args) * num_arguments_per_2D_tensor() + 4 + (_reinterpret_input_as_3d ? 1 : 0) + _num_post_op_args; + idx0 = (4 + _num_post_op_args) * num_arguments_per_2D_tensor() + 7 + (_reinterpret_input_as_3d ? 1 : 0) + _num_post_op_args; } else { - idx0 = (3 + _num_post_op_args) * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0) + _num_post_op_args; + idx0 = (3 + _num_post_op_args) * num_arguments_per_2D_tensor() + 6 + (_reinterpret_input_as_3d ? 1 : 0) + _num_post_op_args; } const unsigned int total_cross_plane_pad = dst->info()->padding().top + dst->info()->padding().bottom; _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad)); @@ -455,6 +458,12 @@ void ClGemmMatrixMultiplyNativeKernel::run_op(ITensorPack &tensors, const Window const auto post_op_arg = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(experimental::get_post_op_arg_type(i))); _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(post_op_arg->info()->strides_in_bytes()[2])); } + + // Pass m, n and k at runtime + _kernel.setArg<cl_int>(idx++, _m); + _kernel.setArg<cl_int>(idx++, _n); + _kernel.setArg<cl_int>(idx++, _k); + enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items); } while(window.slide_window_slice_3D(slice)); |