Diffstat (limited to 'src/gpu/cl/kernels')
-rw-r--r--    src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp    26
-rw-r--r--    src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h        1
2 files changed, 15 insertions, 12 deletions
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
index 32e69cabda..06a0bdee17 100644
--- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
@@ -60,12 +60,11 @@ inline std::pair<int, int> adjust_m0_n0(int m0, int n0, int m, int n)
Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
{
const bool adj_lhs = matmul_kernel_info.adj_lhs;
- const bool adj_rhs = matmul_kernel_info.adj_rhs;
- const int m0 = matmul_kernel_info.m0;
- const int n0 = matmul_kernel_info.n0;
- const int k0 = matmul_kernel_info.k0;
+ const int m0 = matmul_kernel_info.m0;
+ const int n0 = matmul_kernel_info.n0;
+ const int k0 = matmul_kernel_info.k0;
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((adj_lhs || adj_rhs), "adj_lhs and adj_rhs are not supported yet");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((adj_lhs), "adj_lhs is not supported yet");
// Validate M0
ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0");
@@ -84,7 +83,7 @@ Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rh
{
ARM_COMPUTE_UNUSED(matmul_kernel_info);
const size_t lhs_k = lhs_shape.x();
- const size_t rhs_k = rhs_shape.y();
+ const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y();
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match.");
ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR((lhs_k % mmul_k0) != 0, "K dimension must be a multiple of %d", mmul_k0);
@@ -177,9 +176,11 @@ void ClMatMulNativeMMULKernel::configure(const ClCompileContext &compile_context
const int m = dst->dimension(1);
const int n = dst->dimension(0);
- const int k = lhs->tensor_shape().x();
- _m = m;
- _n = n;
+ const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x();
+
+ _m = m;
+ _n = n;
+ _k = k;
int m0{};
int n0{};
@@ -199,15 +200,15 @@ void ClMatMulNativeMMULKernel::configure(const ClCompileContext &compile_context
build_opts.add_option_if(lhs->data_type() == DataType::F16, "-DHALF_PRECISION");
build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
- build_opts.add_option("-DK0=" + support::cpp11::to_string(matmul_kernel_info.k0));
build_opts.add_option("-DM0_LEFTOVER=" + support::cpp11::to_string(m0_leftover));
build_opts.add_option("-DN0_LEFTOVER=" + support::cpp11::to_string(n0_leftover));
build_opts.add_option("-DMMUL_M0=" + support::cpp11::to_string(mmul_m0));
build_opts.add_option("-DMMUL_N0=" + support::cpp11::to_string(mmul_n0));
build_opts.add_option("-DMMUL_K0=" + support::cpp11::to_string(mmul_k0));
- build_opts.add_option("-DK=" + support::cpp11::to_string(k));
- std::string kernel_name("mat_mul_native_mmul_nt_nt");
+ std::string kernel_name("mat_mul_native_mmul");
+ kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt";
+ kernel_name += matmul_kernel_info.adj_rhs ? "_t" : "_nt";
// A macro guard to compile ONLY the kernel of interest
build_opts.add_option("-D" + upper_string(kernel_name));
@@ -250,6 +251,7 @@ void ClMatMulNativeMMULKernel::run_op(ITensorPack &tensors, const Window &window
// Pass m and n at runtime as signed ints, to ensure results of any subtractions they could be operand in, would still be signed.
_kernel.setArg<cl_int>(idx++, _m);
_kernel.setArg<cl_int>(idx++, _n);
+ _kernel.setArg<cl_int>(idx++, _k);
// LWS_x should be multiple of 16 at least. (32, 2) has been chosen to have more work-items on a single core
// LWS also enforces the order of execution of the work items which improves cache utilization
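For readers skimming the shape handling above, here is a minimal, self-contained sketch (not part of the patch) of how the transpose flags select the K dimension, mirroring the selection added to configure() for the LHS and to validate_input_shapes() for the RHS. TensorShape2D, MatMulInfoSketch and the helper names below are illustrative stand-ins, not the real arm_compute types.

#include <cstddef>

// Illustrative stand-ins for the Compute Library types referenced in the patch.
struct TensorShape2D
{
    std::size_t x; // dimension 0 (width)
    std::size_t y; // dimension 1 (height)
};

struct MatMulInfoSketch
{
    bool adj_lhs{ false }; // LHS is transposed
    bool adj_rhs{ false }; // RHS is transposed
};

// K sits on the x-axis of a non-transposed LHS and moves to the y-axis when
// the LHS is transposed; the RHS is mirrored (K on y, or on x when transposed).
inline std::size_t lhs_k(const TensorShape2D &lhs, const MatMulInfoSketch &info)
{
    return info.adj_lhs ? lhs.y : lhs.x;
}

inline std::size_t rhs_k(const TensorShape2D &rhs, const MatMulInfoSketch &info)
{
    return info.adj_rhs ? rhs.x : rhs.y;
}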
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h
index 26fe08c466..79f675d03b 100644
--- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h
@@ -86,6 +86,7 @@ public:
private:
int _m{ 1 };
int _n{ 1 };
+ int _k{ 1 };
};
} // namespace kernels
} // namespace opencl
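As a closing note, a hedged sketch of the kernel-name assembly the .cpp hunk introduces: an "_nt"/"_t" suffix pair encodes the adj_lhs/adj_rhs flags, and the resulting name also serves as the macro guard passed via -D. The struct below is an illustrative stand-in; note that validate_matmul_kernel_info() above still rejects adj_lhs, so only the "_nt_*" variants are reachable with this patch.

#include <string>

struct MatMulInfoSketch
{
    bool adj_lhs{ false };
    bool adj_rhs{ false };
};

inline std::string mmul_kernel_name(const MatMulInfoSketch &info)
{
    std::string name("mat_mul_native_mmul");
    name += info.adj_lhs ? "_t" : "_nt"; // "_t" when the LHS is transposed
    name += info.adj_rhs ? "_t" : "_nt"; // "_t" when the RHS is transposed
    return name;                         // e.g. "mat_mul_native_mmul_nt_t"
}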