diff options
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 73 |
1 files changed, 59 insertions, 14 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp index b667621426..2b004c23db 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp @@ -48,8 +48,8 @@ namespace { using ElementsProcessed = Steps; -inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, - bool fp_mixed_precision) +inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float beta, + bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0); @@ -61,9 +61,20 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D"); + const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; + const bool has_vec_c = input2 != nullptr && beta != 0.f; + ARM_COMPUTE_RETURN_ERROR_ON_MSG(has_vec_c && !is_beta_one, "Adding input2 is only supported for beta equal to 1"); + if(!is_interleaved_transposed) { ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1)); + + if(has_vec_c) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->dimension(0) != input1->dimension(0), "Length of Vector C must match the number of columns of matrix B"); + } } else { @@ -101,6 +112,12 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1); + + if(has_vec_c) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor"); + } } if(output->total_size() != 0) @@ -113,10 +130,11 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i return Status{}; } -inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output, - bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, +inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, + float beta, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, ElementsProcessed &num_elements_processed) { + ARM_COMPUTE_UNUSED(beta); bool window_changed = false; Window win{}; Window win_out{}; @@ -126,6 +144,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1]; bool reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d(); bool reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0); + const bool has_vec_c = input2 != nullptr && beta != 0.f; // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. @@ -176,6 +195,11 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor + if(has_vec_c) + { + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x); + window_changed = window_changed || update_window_and_padding(win, input2_access); + } output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape())); } @@ -209,6 +233,11 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor + if(has_vec_c) + { + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x); + window_changed = window_changed || update_window_and_padding(win, input2_access); + } Coordinates coord; coord.set_num_dimensions(output->num_dimensions()); @@ -227,20 +256,22 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu } // namespace CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel() - : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false) + : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _has_vec_c(false) { } -void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, - bool fp_mixed_precision) +void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, + bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision) { ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output); // Perform validate step - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, fp_mixed_precision)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta, + is_interleaved_transposed, reshape_info, fp_mixed_precision)); _input0 = input0; _input1 = input1; + _input2 = input2; _output = output; _reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d(); _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0); @@ -266,7 +297,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen ElementsProcessed num_elements_processed{}; // Configure kernel window - auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, gpu_target, num_elements_processed); + auto win_config = validate_and_configure_window(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta, is_interleaved_transposed, reshape_info, + gpu_target, num_elements_processed); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); ICLKernel::configure_internal(win_config.second); @@ -288,6 +320,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen const bool is_bifrost = get_arch_from_target(gpu_target) == GPUTarget::BIFROST; + _has_vec_c = input2 != nullptr && beta != 0.f; + std::string kernel_name; if(is_interleaved_transposed) { @@ -351,6 +385,9 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x())); } + // Configure matrix C addition if necessary + build_opts.add_option_if(_has_vec_c, "-DADD_VEC_C"); + // Create kernel _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); @@ -373,16 +410,18 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen _config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1))); } -Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed, - const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision) +Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, + bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision) { // Note: num_elements_processed will be set in validate_and_configure_window() ElementsProcessed num_elements_processed{}; ARM_COMPUTE_UNUSED(alpha); - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info, fp_mixed_precision)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, input2, output, beta, is_interleaved_transposed, reshape_info, fp_mixed_precision)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), + (input2 != nullptr) ? input2->clone().get() : nullptr, output->clone().get(), + beta, is_interleaved_transposed, reshape_info, gpu_target, @@ -409,10 +448,12 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1)); slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1)); + const unsigned int num_arguments_vec_c = (_has_vec_c) ? num_arguments_per_1D_tensor() : 0; + if(_reinterpret_input_as_3d) { // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3; + const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + num_arguments_vec_c; const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom; _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad)); } @@ -420,7 +461,7 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que if(_reinterpret_output_as_3d) { // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0); + const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0) + num_arguments_vec_c; const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom; _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad)); } @@ -438,6 +479,10 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que unsigned int idx = 0; add_2D_tensor_argument(idx, _input0, slice); add_2D_tensor_argument(idx, _input1, slice_b); + if(_has_vec_c) + { + add_1D_tensor_argument(idx, _input2, slice); + } add_2D_tensor_argument(idx, _output, slice); _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[2])); _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[2])); |