diff options
Diffstat (limited to 'src/core/CL/kernels')
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp | 21 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 2 |
2 files changed, 13 insertions, 10 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp index 5a46a1e013..8f20de1ea1 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp @@ -159,23 +159,18 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe num_elems_processed_per_iteration_x = rhs_info.n0; num_elems_processed_per_iteration_y = lhs_info.m0; - // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor - // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic - const unsigned int m = gemm_info.m; - const unsigned int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y; - win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); AccessWindowStatic input0_access(input0, 0, 0, - ceil_to_multiple(input0->dimension(0), num_elems_processed_per_iteration_y), + input0->dimension(0), input0->dimension(1)); AccessWindowStatic input1_access(input1, 0, 0, - ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), + input1->dimension(0), input1->dimension(1)); AccessWindowStatic output_access(output, 0, 0, - ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x), - output->dimension(1) + bottom_pad); + output->dimension(0), + output->dimension(1)); if(input2 != nullptr) { @@ -256,6 +251,12 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi const bool enable_mixed_precision = gemm_info.fp_mixed_precision; const DataType data_type = input0->info()->data_type(); + // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks at the end of a row/column if any. This is to avoid padding. + const unsigned int internal_m = _reinterpret_output_as_3d ? gemm_info.m : output->info()->dimension(1); + + const unsigned int partial_store_m0 = internal_m % lhs_info.m0; + const unsigned int partial_store_n0 = gemm_info.n % rhs_info.n0; + // Create build options CLBuildOptions build_opts; build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha)); @@ -286,6 +287,8 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0)); build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0)); build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0)); + build_opts.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(partial_store_m0)); + build_opts.add_option("-DPARTIAL_STORE_N0=" + support::cpp11::to_string(partial_store_n0)); std::string kernel_name("gemm_mm_reshaped_"); kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_"; diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index e65726b234..cf77c70bfa 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -162,7 +162,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe input0->dimension(0), input0->dimension(1)); AccessWindowStatic input1_access(input1, 0, 0, - ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), + input1->dimension(0), input1->dimension(1)); AccessWindowStatic output_access(output, 0, 0, output->dimension(0), |