diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp index 64e99332fd..6a450b652b 100644 --- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp @@ -201,7 +201,6 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi _use_dummy_work_items = preferred_dummy_work_items_support(CLKernelLibrary::get().get_device()); _add_bias = src2 != nullptr; _export_to_cl_image = rhs_info.export_to_cl_image; - _k = gemm_info.k; _num_post_op_args = gemm_info.post_ops.total_num_arguments(); // Check if we need to slide the matrix B @@ -230,6 +229,9 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi const unsigned int partial_store_m0 = internal_m % lhs_info.m0; const unsigned int partial_store_n0 = gemm_info.n % rhs_info.n0; + _m = gemm_info.m; + _n = gemm_info.n; + _k = gemm_info.k; // Create build options CLBuildOptions build_opts; @@ -250,9 +252,6 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi build_opts.add_option("-DRHS_HEIGHT=" + support::cpp11::to_string(src1->dimension(1))); build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type)); build_opts.add_option("-DDATA_TYPE_ACCUMULATOR=" + (enable_mixed_precision ? get_cl_type_from_data_type(DataType::F32) : get_cl_type_from_data_type(data_type))); - build_opts.add_option("-DM=" + support::cpp11::to_string(gemm_info.m)); - build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n)); - build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k)); build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0)); @@ -278,6 +277,9 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi kernel_name += rhs_info.export_to_cl_image ? "_texture" : ""; post_op_utils.set_post_ops_cl_kernel_name(kernel_name, gemm_info.post_ops); + // A macro guard to compile ONLY the kernel of interest + build_opts.add_option("-D" + upper_string(kernel_name)); + // Create kernel _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); @@ -399,9 +401,6 @@ void ClGemmMatrixMultiplyReshapedKernel::run_op(ITensorPack &tensors, const Wind add_2D_tensor_argument(idx, post_op_arg, slice); } - // K dimension (not used if _export_to_cl_image == true) - _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_k)); - // LHS stride_z _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(src0->info()->strides_in_bytes()[2])); @@ -429,6 +428,13 @@ void ClGemmMatrixMultiplyReshapedKernel::run_op(ITensorPack &tensors, const Wind _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(total_cross_plane_pad)); } + // Pass m, n and k at runtime + _kernel.setArg<cl_int>(idx++, _m); + _kernel.setArg<cl_int>(idx++, _n); + + // K dimension (not used if _export_to_cl_image == true) + _kernel.setArg<cl_int>(idx++, _k); + // Dispatch kernel enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items); } @@ -436,4 +442,4 @@ void ClGemmMatrixMultiplyReshapedKernel::run_op(ITensorPack &tensors, const Wind } } // namespace kernels } // namespace opencl -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |