diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClNativeMatMulKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClNativeMatMulKernel.cpp | 59 |
1 files changed, 33 insertions, 26 deletions
diff --git a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp index 6a4db65922..ffbaf49c02 100644 --- a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp +++ b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp @@ -50,28 +50,40 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) const int k0 = matmul_kernel_info.k0; // Validate M0 - if(!adj_lhs) - { - // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient - ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0 for Lhs non-transposed"); - } - else + ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0"); + + if(adj_lhs) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 & (m0 - 1)) && (m0 != 3) && (m0 > 16), "Only 1,2,3,4,8,16 are supported for N0 for Lhs transposed"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(((m0 & (m0 - 1)) && (m0 != 3)) || (m0 > 16), "Only 1,2,3,4,8,16 are supported for N0 for Lhs transposed"); } // Validate N0 - ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 & (n0 - 1)) && (n0 != 3) && (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0"); // Validate K0 - if(adj_lhs && !adj_rhs) + ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0"); + if(!adj_lhs || adj_rhs) { - // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient - ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0 for Lhs transposed & Rhs non-transposed"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(((k0 & (k0 - 1)) && (k0 != 3)) || (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0"); } - else + + return Status{}; +} + +Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info) +{ + const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x(); + const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y(); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match."); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty"); + + constexpr size_t batch_dim_start = 2; + for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 & (k0 - 1)) && (k0 != 3) && (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported"); } return Status{}; @@ -87,15 +99,14 @@ Status ClNativeMatMulKernel::validate(const ITensorInfo *lhs, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); if(output->total_size() != 0) { - const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, output); } - ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && matmul_kernel_info.adj_rhs, "LHS T and RHS T not implemented"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && !matmul_kernel_info.adj_rhs, "LHS T and RHS NT not implemented"); return Status{}; } @@ -105,14 +116,15 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output, matmul_kernel_info); // output tensor auto initialization if not yet initialized - auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); + auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info)); - const int m = output->dimension(1); - const int n = output->dimension(0); - const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x(); + const int m = output->dimension(1); + const int n = output->dimension(0); + const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x(); + const bool adj_lhs = matmul_kernel_info.adj_lhs; - int m0 = std::min(matmul_kernel_info.m0, m); + int m0 = adj_lhs ? adjust_vec_size(matmul_kernel_info.m0, m) : std::min(matmul_kernel_info.m0, m); int n0 = adjust_vec_size(matmul_kernel_info.n0, n); // Configure kernel window @@ -137,11 +149,6 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt"; kernel_name += matmul_kernel_info.adj_rhs ? "_t" : "_nt"; - if(matmul_kernel_info.adj_lhs) - { - ARM_COMPUTE_ERROR("Only Implemented LHS non-transposed kernels"); - } - // A macro guard to compile ONLY the kernel of interest build_opts.add_option("-D" + upper_string(kernel_name)); |