aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp21
1 files changed, 12 insertions, 9 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 5a46a1e013..8f20de1ea1 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -159,23 +159,18 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
num_elems_processed_per_iteration_x = rhs_info.n0;
num_elems_processed_per_iteration_y = lhs_info.m0;
- // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
- // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
- const unsigned int m = gemm_info.m;
- const unsigned int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
-
win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
AccessWindowStatic input0_access(input0, 0, 0,
- ceil_to_multiple(input0->dimension(0), num_elems_processed_per_iteration_y),
+ input0->dimension(0),
input0->dimension(1));
AccessWindowStatic input1_access(input1, 0, 0,
- ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x),
+ input1->dimension(0),
input1->dimension(1));
AccessWindowStatic output_access(output, 0, 0,
- ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
- output->dimension(1) + bottom_pad);
+ output->dimension(0),
+ output->dimension(1));
if(input2 != nullptr)
{
@@ -256,6 +251,12 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi
const bool enable_mixed_precision = gemm_info.fp_mixed_precision;
const DataType data_type = input0->info()->data_type();
+ // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks at the end of a row/column if any. This is to avoid padding.
+ const unsigned int internal_m = _reinterpret_output_as_3d ? gemm_info.m : output->info()->dimension(1);
+
+ const unsigned int partial_store_m0 = internal_m % lhs_info.m0;
+ const unsigned int partial_store_n0 = gemm_info.n % rhs_info.n0;
+
// Create build options
CLBuildOptions build_opts;
build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha));
@@ -286,6 +287,8 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi
build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0));
build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0));
build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
+ build_opts.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(partial_store_m0));
+ build_opts.add_option("-DPARTIAL_STORE_N0=" + support::cpp11::to_string(partial_store_n0));
std::string kernel_name("gemm_mm_reshaped_");
kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_";