diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClMatMulNativeKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClMatMulNativeKernel.cpp | 25 |
1 files changed, 4 insertions, 21 deletions
diff --git a/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp index 8f8ccfc41f..41ba5d5e25 100644 --- a/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp @@ -23,20 +23,21 @@ */ #include "src/gpu/cl/kernels/ClMatMulNativeKernel.h" -#include "arm_compute/core/utils/ActivationFunctionUtils.h" #include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/ITensorPack.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/utils/ActivationFunctionUtils.h" +#include "arm_compute/core/utils/StringUtils.h" #include "arm_compute/core/utils/helpers/AdjustVecSize.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "arm_compute/core/utils/StringUtils.h" #include "src/common/utils/Log.h" #include "src/core/CL/CLUtils.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" +#include "src/gpu/cl/kernels/helpers/MatMulKernelHelpers.h" #include "support/Cast.h" #include "support/StringSupport.h" @@ -79,24 +80,6 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) return Status{}; } -Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info) -{ - const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x(); - const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y(); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty"); - - constexpr size_t batch_dim_start = 2; - for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported"); - } - - return Status{}; -} - Status validate_export_to_cl_image(const ITensorInfo *rhs, const MatMulKernelInfo &matmul_kernel_info) { ARM_COMPUTE_RETURN_ERROR_ON(matmul_kernel_info.export_rhs_to_cl_image && rhs->lock_paddings()); @@ -131,7 +114,7 @@ Status ClMatMulNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_export_to_cl_image(rhs, matmul_kernel_info)); const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info); |