about summary refs log tree commit diff
path: root/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp')
-rw-r--r-- src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp | 12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp b/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp
index 778b9b9fa2..b3a03880ed 100644
--- a/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp
@@ -123,10 +123,9 @@ void ClGemmReshapeRhsMatrixKernel::configure(const CLCompileContext &compile_con
CLBuildOptions build_opts;
build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
- build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
- build_opts.add_option_if(rhs_info.transpose, "-DTRANSPOSE");
build_opts.add_option_if(rhs_info.interleave, "-DINTERLEAVE");
- build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(src->dimension(1)));
+ build_opts.add_option_if(rhs_info.transpose, "-DRESHAPE_RHS_T");
+ build_opts.add_option_if(!rhs_info.transpose, "-DRESHAPE_RHS_NT");
build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(src->element_size()));
std::string kernel_name("gemm_reshape_rhs_matrix_");
@@ -139,6 +138,9 @@ void ClGemmReshapeRhsMatrixKernel::configure(const CLCompileContext &compile_con
auto win_config = validate_and_configure_window(src, dst, rhs_info);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure_internal(win_config.second);
+
+ unsigned int idx = 2 * num_arguments_per_3d_tensor_nhw();
+ _kernel.setArg<cl_int>(idx++, rhs_info.h0);
}
Status ClGemmReshapeRhsMatrixKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMRHSMatrixInfo &rhs_info)
@@ -164,8 +166,8 @@ void ClGemmReshapeRhsMatrixKernel::run_op(ITensorPack &tensors, const Window &wi
do
{
unsigned int idx = 0;
- add_3D_tensor_argument(idx, src, slice);
- add_3D_tensor_argument(idx, dst, slice);
+ add_3d_tensor_nhw_argument(idx, src);
+ add_3d_tensor_nhw_argument(idx, dst);
enqueue(queue, *this, slice, lws_hint());
}
while(window.slide_window_slice_3D(slice));