Diffstat (limited to 'src/gpu/cl/kernels/ClNativeMatMulKernel.cpp')
-rw-r--r--  src/gpu/cl/kernels/ClNativeMatMulKernel.cpp  59
1 file changed, 33 insertions(+), 26 deletions(-)
diff --git a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp
index 6a4db65922..ffbaf49c02 100644
--- a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp
+++ b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp
@@ -50,28 +50,40 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
const int k0 = matmul_kernel_info.k0;
// Validate M0
- if(!adj_lhs)
- {
- // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0 for Lhs non-transposed");
- }
- else
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0");
+
+ if(adj_lhs)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 & (m0 - 1)) && (m0 != 3) && (m0 > 16), "Only 1,2,3,4,8,16 are supported for N0 for Lhs transposed");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(((m0 & (m0 - 1)) && (m0 != 3)) || (m0 > 16), "Only 1,2,3,4,8,16 are supported for M0 for Lhs transposed");
}
// Validate N0
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 & (n0 - 1)) && (n0 != 3) && (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0");
// Validate K0
- if(adj_lhs && !adj_rhs)
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0");
+ if(!adj_lhs || adj_rhs)
{
- // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0 for Lhs transposed & Rhs non-transposed");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(((k0 & (k0 - 1)) && (k0 != 3)) || (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0");
}
- else
+
+ return Status{};
+}
+
+Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info)
+{
+ const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x();
+ const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y();
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty");
+
+ constexpr size_t batch_dim_start = 2;
+ for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 & (k0 - 1)) && (k0 != 3) && (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported");
}
return Status{};
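Note on the new checks above: (v & (v - 1)) is non-zero exactly when v is not a power of two, so the full condition accepts only {1, 2, 3, 4, 8, 16}. Switching the operator before the upper-bound test from && to || is the substantive fix: with &&, the macro only fired for values above 16 that were neither a power of two nor 3, so both 5 and 32 slipped through. A minimal standalone sketch of the predicate (hypothetical helper name, not part of the library):

#include <cassert>

// Mirrors the block-size predicate in validate_matmul_kernel_info():
// accepted iff power of two or exactly 3, and no larger than 16.
// v & (v - 1) is zero only for powers of two.
static bool is_supported_block_size(int v)
{
    return v >= 1 && (((v & (v - 1)) == 0) || v == 3) && v <= 16;
}

int main()
{
    for(int v : {1, 2, 3, 4, 8, 16})
    {
        assert(is_supported_block_size(v));  // exactly the documented set
    }
    for(int v : {0, 5, 6, 32})
    {
        assert(!is_supported_block_size(v)); // 5 and 32 passed the old && form
    }
    return 0;
}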
@@ -87,15 +99,14 @@ Status ClNativeMatMulKernel::validate(const ITensorInfo *lhs, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
if(output->total_size() != 0)
{
- const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
+ const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, output);
}
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && matmul_kernel_info.adj_rhs, "LHS T and RHS T not implemented");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && !matmul_kernel_info.adj_rhs, "LHS T and RHS NT not implemented");
return Status{};
}
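With the two ARM_COMPUTE_RETURN_ERROR_ON_MSG lines dropped, the transposed-Lhs variants are no longer rejected, and validate() now delegates shape checking to validate_input_shapes() plus the renamed compute_matmul_shape(). For intuition, a rough 2D-only sketch of the expected-shape rule, assuming shape.x() is the innermost (column) dimension as in the K-extraction code above; this is not the library's implementation:

#include <cstddef>
#include <iostream>

struct Shape2D { size_t x; size_t y; };  // x: columns, y: rows

// M comes from Lhs (y, or x when transposed); N from Rhs (x, or y when transposed).
static Shape2D expected_matmul_shape(Shape2D lhs, Shape2D rhs, bool adj_lhs, bool adj_rhs)
{
    const size_t m = adj_lhs ? lhs.x : lhs.y;
    const size_t n = adj_rhs ? rhs.y : rhs.x;
    return Shape2D{n, m};  // output: N columns by M rows
}

int main()
{
    // Lhs 2x3 (K = 3), Rhs 3x4, no transpose: output should be N=4 by M=2.
    const Shape2D out = expected_matmul_shape({3, 2}, {4, 3}, false, false);
    std::cout << out.x << "x" << out.y << "\n";  // prints 4x2
    return 0;
}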
@@ -105,14 +116,15 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT
ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output, matmul_kernel_info);
// output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)));
+ auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)));
ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info));
- const int m = output->dimension(1);
- const int n = output->dimension(0);
- const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x();
+ const int m = output->dimension(1);
+ const int n = output->dimension(0);
+ const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x();
+ const bool adj_lhs = matmul_kernel_info.adj_lhs;
- int m0 = std::min(matmul_kernel_info.m0, m);
+ int m0 = adj_lhs ? adjust_vec_size(matmul_kernel_info.m0, m) : std::min(matmul_kernel_info.m0, m);
int n0 = adjust_vec_size(matmul_kernel_info.n0, n);
// Configure kernel window
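m0 now goes through adjust_vec_size() when Lhs is transposed because, in that variant, m0 acts as an OpenCL vector width and must stay in the supported power-of-two (or 3) set; std::min(matmul_kernel_info.m0, m) could yield e.g. 5. A simplified sketch of the halving semantics, modelled on the ACL helper rather than copied from it:

#include <iostream>

// Halve the requested width until it fits in the dimension, so the result
// stays a power of two; dim == 3 is special-cased. Simplified sketch.
static unsigned int adjust_vec_size_sketch(unsigned int vec_size, unsigned int dim)
{
    if(vec_size >= dim && dim == 3)
    {
        return dim;
    }
    while(vec_size > dim)
    {
        vec_size >>= 1;
    }
    return vec_size;
}

int main()
{
    std::cout << adjust_vec_size_sketch(8, 5) << "\n";  // 4, not min(8,5) = 5
    std::cout << adjust_vec_size_sketch(8, 3) << "\n";  // 3 (special case)
    std::cout << adjust_vec_size_sketch(4, 16) << "\n"; // 4, already fits
    return 0;
}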
@@ -137,11 +149,6 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT
kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt";
kernel_name += matmul_kernel_info.adj_rhs ? "_t" : "_nt";
- if(matmul_kernel_info.adj_lhs)
- {
- ARM_COMPUTE_ERROR("Only Implemented LHS non-transposed kernels");
- }
-
// A macro guard to compile ONLY the kernel of interest
build_opts.add_option("-D" + upper_string(kernel_name));
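With the adj_lhs guard removed, all four transpose variants are now selectable by name. A standalone sketch of how the suffixing and the macro guard compose; the base name "mat_mul_native" is an assumption, since it sits outside the visible hunk:

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>

int main()
{
    const bool adj_lhs = true;
    const bool adj_rhs = false;

    std::string kernel_name = "mat_mul_native";  // assumed base name
    kernel_name += adj_lhs ? "_t" : "_nt";
    kernel_name += adj_rhs ? "_t" : "_nt";

    // upper_string equivalent: the -D guard compiles only this variant.
    std::string guard = kernel_name;
    std::transform(guard.begin(), guard.end(), guard.begin(),
                   [](unsigned char c) { return std::toupper(c); });

    std::cout << kernel_name << "\n";   // mat_mul_native_t_nt
    std::cout << "-D" + guard << "\n";  // -DMAT_MUL_NATIVE_T_NT
    return 0;
}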