aboutsummaryrefslogtreecommitdiff
path: root/src/gpu
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu')
-rw-r--r--src/gpu/cl/ClKernelLibrary.cpp2
-rw-r--r--src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp12
2 files changed, 9 insertions, 5 deletions
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 5355cb7402..de2e9f9742 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -320,7 +320,9 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
{ "l2_normalize_y", "common/l2_normalize.cl" },
{ "l2_normalize_z", "common/l2_normalize.cl" },
{ "mat_mul_native_mmul_nt_nt", "common/mat_mul_mmul.cl" },
+ { "mat_mul_native_mmul_t_nt", "common/mat_mul_mmul.cl" },
{ "mat_mul_native_mmul_nt_t", "common/mat_mul_mmul.cl" },
+ { "mat_mul_native_mmul_t_t", "common/mat_mul_mmul.cl" },
{ "mat_mul_native_nt_nt", "common/mat_mul.cl" },
{ "mat_mul_native_nt_t", "common/mat_mul.cl" },
{ "mat_mul_native_t_nt", "common/mat_mul.cl" },
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
index 06a0bdee17..4630ec08e9 100644
--- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
@@ -64,14 +64,17 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
const int n0 = matmul_kernel_info.n0;
const int k0 = matmul_kernel_info.k0;
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((adj_lhs), "adj_lhs is not supported yet");
-
// Validate M0
ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0");
+ if(adj_lhs)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 != 1) && (m0 != 2) && (m0 != 3) && (m0 != 4) && (m0 != 8) && (m0 != 16), "Only 1,2,3,4,8,16 are supported for M0 for Lhs transposed");
+ }
+
// Validate N0
ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 != 1) && (n0 != 2) && (n0 != 3) && (n0 != 4) && (n0 != 8) && (n0 != 16), "Only 1,2,3,4,8,16 are supported for N0");
// Validate K0
ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 != 1), "Only 1 is supported for k0");
@@ -81,8 +84,7 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info)
{
- ARM_COMPUTE_UNUSED(matmul_kernel_info);
- const size_t lhs_k = lhs_shape.x();
+ const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x();
const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y();
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match.");