diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp index aa806978ef..29f9180bf4 100644 --- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp @@ -240,7 +240,9 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsKernel::configure(const CLCompileContext // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks at the end of a row/column if any. This is to avoid padding. const unsigned int partial_store_m0 = internal_m % internal_m0; const unsigned int partial_store_n0 = gemm_info.n % rhs_info.n0; - + _m = internal_m; + _n = gemm_info.n; + _k = gemm_info.k; // Create build options CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src0->data_type())); @@ -253,9 +255,6 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsKernel::configure(const CLCompileContext build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS"); build_opts.add_option_if(rhs_info.export_to_cl_image, "-DOPENCL_IMAGE_SUPPORT"); build_opts.add_option("-DRHS_HEIGHT=" + support::cpp11::to_string(src1->dimension(1))); - build_opts.add_option("-DM=" + support::cpp11::to_string(internal_m)); - build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n)); - build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k)); build_opts.add_option("-DM0=" + support::cpp11::to_string(internal_m0)); build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0)); @@ -286,6 +285,9 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsKernel::configure(const CLCompileContext kernel_name += rhs_info.export_to_cl_image ? "_texture" : ""; post_op_utils.set_post_ops_cl_kernel_name(kernel_name, gemm_info.post_ops); + // A macro guard to compile ONLY the kernel of interest + build_opts.add_option("-D" + upper_string(kernel_name)); + // Create kernel _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); @@ -447,6 +449,11 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsKernel::run_op(ITensorPack &tensors, con _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(total_cross_plane_pad_out)); } + // Pass m, n and k at runtime as signed ints, to ensure results of any subractions they could be operand in, would still be signed. + _kernel.setArg<cl_int>(idx++, _m); + _kernel.setArg<cl_int>(idx++, _n); + _kernel.setArg<cl_int>(idx++, _k); + enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items); } while(window.slide_window_slice_3D(slice)); |