aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp24
1 files changed, 15 insertions, 9 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index aa69ed06d1..fd0978230d 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -96,7 +96,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
const int k = reshape_info.k();
const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
- rhs_info.n0 = 16 / input1->element_size();
+ rhs_info.n0 = max_cl_vector_width / input1->element_size();
rhs_info.k0 = 1;
rhs_info.h0 = mult_transpose1xW_width;
rhs_info.interleave = false;
@@ -354,6 +354,12 @@ void CLGEMMMatrixMultiplyKernel::configure(const CLCompileContext &compile_conte
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure_internal(win_config.second);
+ const unsigned int h_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(1) : input0->info()->dimension(1);
+ const unsigned int d_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(2) : input0->info()->dimension(2);
+
+ const unsigned int m0 = num_elements_processed.y();
+ const unsigned int n0 = num_elements_processed.x();
+
// Create build options
CLBuildOptions build_opts;
@@ -363,8 +369,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const CLCompileContext &compile_conte
build_opts.add_option_if(reshape_info.broadcast_bias(), "-DBROADCAST_BIAS");
build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
- build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
- build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(h_gemm_3d));
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(d_gemm_3d));
build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
build_opts.add_option_if(activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(activation_info.activation())));
build_opts.add_option_if(activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(activation_info.a()));
@@ -378,9 +384,9 @@ void CLGEMMMatrixMultiplyKernel::configure(const CLCompileContext &compile_conte
const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
- build_opts.add_option("-DCOLS_B=" + support::cpp11::to_string(input1->info()->dimension(0)));
- build_opts.add_option("-DMULT_TRANSPOSE1XW_WIDTH=" + support::cpp11::to_string(mult_transpose1xW_width));
- build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
+ build_opts.add_option("-DK=" + support::cpp11::to_string(input1->info()->dimension(0) / (n0 * mult_transpose1xW_width)));
+ build_opts.add_option("-DH0=" + support::cpp11::to_string(mult_transpose1xW_width));
+ build_opts.add_option("-DV0=" + support::cpp11::to_string(mult_interleave4x4_height));
if(is_data_type_float(data_type) && is_bifrost)
{
@@ -398,7 +404,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const CLCompileContext &compile_conte
}
else // The input tensors have not been reshaped
{
- build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
+ build_opts.add_option("-DK=" + support::cpp11::to_string(input0->info()->dimension(0)));
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
// Create kernels according to the architecture, data type and input size.
@@ -431,8 +437,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const CLCompileContext &compile_conte
{
kernel_name = "gemm_mm_floating_point";
}
- build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
- build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
+ build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
+ build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
}
// Create kernel