Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp  73
1 file changed, 59 insertions(+), 14 deletions(-)
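Summary: the patch threads an optional 1D bias vector ("Vector C", input2) and a beta scale factor through validate_arguments(), validate_and_configure_window(), configure(), validate() and run(). The addition is only enabled when input2 is non-null and beta == 1, and is selected in the OpenCL kernel via the -DADD_VEC_C build option. A hedged usage sketch of the new configure() call follows the diff.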
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index b667621426..2b004c23db 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -48,8 +48,8 @@ namespace
{
using ElementsProcessed = Steps;
-inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
- bool fp_mixed_precision)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float beta,
+ bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
@@ -61,9 +61,20 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D");
+ const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f;
+ const bool has_vec_c = input2 != nullptr && beta != 0.f;
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(has_vec_c && !is_beta_one, "Adding input2 is only supported for beta equal to 1");
+
if(!is_interleaved_transposed)
{
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
+
+ if(has_vec_c)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->dimension(0) != input1->dimension(0), "Length of Vector C must match the number of columns of matrix B");
+ }
}
else
{
@@ -101,6 +112,12 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+
+ if(has_vec_c)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor");
+ }
}
if(output->total_size() != 0)
@@ -113,10 +130,11 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
return Status{};
}
-inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output,
- bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target,
+inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output,
+ float beta, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target,
ElementsProcessed &num_elements_processed)
{
+ ARM_COMPUTE_UNUSED(beta);
bool window_changed = false;
Window win{};
Window win_out{};
@@ -126,6 +144,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
bool reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
bool reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0);
+ const bool has_vec_c = input2 != nullptr && beta != 0.f;
// In case both input and output have to be reinterpreted as 3D tensors,
// force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
@@ -176,6 +195,11 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
+ if(has_vec_c)
+ {
+ AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x);
+ window_changed = window_changed || update_window_and_padding(win, input2_access);
+ }
output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
}
@@ -209,6 +233,11 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
+ if(has_vec_c)
+ {
+ AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x);
+ window_changed = window_changed || update_window_and_padding(win, input2_access);
+ }
Coordinates coord;
coord.set_num_dimensions(output->num_dimensions());
@@ -227,20 +256,22 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
} // namespace
CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
- : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false)
+ : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _has_vec_c(false)
{
}
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
- bool fp_mixed_precision)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta,
+ bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, fp_mixed_precision));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta,
+ is_interleaved_transposed, reshape_info, fp_mixed_precision));
_input0 = input0;
_input1 = input1;
+ _input2 = input2;
_output = output;
_reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
_reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0);
@@ -266,7 +297,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
ElementsProcessed num_elements_processed{};
// Configure kernel window
- auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, gpu_target, num_elements_processed);
+ auto win_config = validate_and_configure_window(input0->info(), input1->info(), (input2 != nullptr) ? input2->info() : nullptr, output->info(), beta, is_interleaved_transposed, reshape_info,
+ gpu_target, num_elements_processed);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure_internal(win_config.second);
@@ -288,6 +320,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
const bool is_bifrost = get_arch_from_target(gpu_target) == GPUTarget::BIFROST;
+ _has_vec_c = input2 != nullptr && beta != 0.f;
+
std::string kernel_name;
if(is_interleaved_transposed)
{
@@ -351,6 +385,9 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
}
+ // Configure matrix C addition if necessary
+ build_opts.add_option_if(_has_vec_c, "-DADD_VEC_C");
+
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
@@ -373,16 +410,18 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
_config_id += (is_interleaved_transposed ? support::cpp11::to_string(input1->info()->dimension(0)) : support::cpp11::to_string(input1->info()->dimension(1)));
}
-Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed,
- const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
+Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta,
+ bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
{
// Note: num_elements_processed will be set in validate_and_configure_window()
ElementsProcessed num_elements_processed{};
ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info, fp_mixed_precision));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, input2, output, beta, is_interleaved_transposed, reshape_info, fp_mixed_precision));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
input1->clone().get(),
+ (input2 != nullptr) ? input2->clone().get() : nullptr,
output->clone().get(),
+ beta,
is_interleaved_transposed,
reshape_info,
gpu_target,
@@ -409,10 +448,12 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
+ const unsigned int num_arguments_vec_c = (_has_vec_c) ? num_arguments_per_1D_tensor() : 0;
+
if(_reinterpret_input_as_3d)
{
// Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
- const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+ const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + num_arguments_vec_c;
const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
_kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
}
@@ -420,7 +461,7 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
if(_reinterpret_output_as_3d)
{
// Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
- const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
+ const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0) + num_arguments_vec_c;
const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
_kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
}
@@ -438,6 +479,10 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
unsigned int idx = 0;
add_2D_tensor_argument(idx, _input0, slice);
add_2D_tensor_argument(idx, _input1, slice_b);
+ if(_has_vec_c)
+ {
+ add_1D_tensor_argument(idx, _input2, slice);
+ }
add_2D_tensor_argument(idx, _output, slice);
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input0->info()->strides_in_bytes()[2]));
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input1->info()->strides_in_bytes()[2]));
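For orientation, below is a minimal, hypothetical sketch of how a caller could exercise the extended configure() signature introduced by this patch. The tensor names, the include paths and the default-constructed GEMMReshapeInfo are assumptions for illustration, not part of the change; only the parameter order (input0, input1, input2, output, alpha, beta, is_interleaved_transposed, reshape_info, fp_mixed_precision) is taken from the diff above.

// Hypothetical usage sketch (not part of the patch). Tensor names are illustrative;
// include paths follow the library layout assumed for this version of the code.
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "arm_compute/runtime/CL/CLTensor.h"

using namespace arm_compute;

void configure_gemm_with_vector_c(CLTensor &a, CLTensor &b, CLTensor &c, CLTensor &dst)
{
    CLGEMMMatrixMultiplyKernel gemm_kernel;

    const float alpha = 1.0f;
    const float beta  = 1.0f; // validate_arguments() only allows the vector-C addition for beta == 1

    // input2 (Vector C) is the new third argument; passing nullptr (or beta == 0) disables the addition.
    // c must be a 1D tensor, match the data type of a, and (in the non-reshaped path) have as many
    // elements as matrix B has columns.
    gemm_kernel.configure(&a, &b, &c, &dst, alpha, beta,
                          false /* is_interleaved_transposed */,
                          GEMMReshapeInfo() /* assumed default-constructed */,
                          false /* fp_mixed_precision */);
}

When input2 is provided with beta == 1, the kernel build options gain -DADD_VEC_C and run() appends the 1D tensor arguments for input2 between the input1 and output arguments, which is why the cross-plane padding argument indices are offset by num_arguments_vec_c in the hunks above.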