aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp')
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp60
1 files changed, 14 insertions, 46 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp
index 149c92b7a9..04c1cd66c9 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp
@@ -24,11 +24,7 @@
#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "src/core/AccessWindowStatic.h"
#include "src/core/CL/CLUtils.h"
#include "src/core/CL/CLValidate.h"
#include "src/core/helpers/AutoConfiguration.h"
@@ -118,18 +114,15 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info,
- const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info, ElementsProcessed &num_elements_processed)
+Window validate_and_configure_window(ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info, ElementsProcessed &num_elements_processed)
{
+ ARM_COMPUTE_UNUSED(src0, src1, src2);
unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d;
bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
- Window win{};
- Window win_out{};
- bool window_changed = false;
-
// In case both input and dst have to be reinterpreted as 3D tensors,
// force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
// This approach should only be used when the input/dst tensors have pad on the y direction
@@ -138,9 +131,6 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src0, ITens
reinterpret_output_as_3d = false;
}
- // dst tensor auto initialization if not yet initialized
- auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)));
-
TensorInfo tmp_info(*dst);
if(reinterpret_output_as_3d)
@@ -156,28 +146,14 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src0, ITens
num_elems_processed_per_iteration_x = rhs_info.n0;
num_elems_processed_per_iteration_y = lhs_info.m0;
- win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- win_out = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
-
- if(src2 != nullptr)
- {
- const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x;
-
- AccessWindowStatic src2_access(src2, 0, 0,
- ceil_to_multiple(src2->dimension(0), bias_processed_per_iteration_x),
- src2->dimension(1));
-
- window_changed = update_window_and_padding(win, src2_access);
- }
+ Window win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
- Window collapsed = win;
const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(dst->num_dimensions()), 2u);
- collapsed = win.collapse(win, dimension_to_collapse);
+ Window collapsed = win.collapse(win, dimension_to_collapse);
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, collapsed);
+ return collapsed;
}
} // namespace
@@ -187,7 +163,7 @@ ClGemmMatrixMultiplyReshapedOnlyRhsKernel::ClGemmMatrixMultiplyReshapedOnlyRhsKe
}
void ClGemmMatrixMultiplyReshapedOnlyRhsKernel::configure(const CLCompileContext &compile_context,
- ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float alpha, float beta,
+ const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, float alpha, float beta,
const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
@@ -201,7 +177,10 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsKernel::configure(const CLCompileContext
_export_to_cl_image = rhs_info.export_to_cl_image;
_has_pad_y = gemm_info.has_pad_y;
- auto padding_info = get_padding_info({ src0, src1, dst });
+ // dst tensor auto initialization if not yet initialized
+ auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)));
+
+ auto padding_info = get_padding_info({ src0, src1, src2, dst });
// In case both input and dst have to be reinterpreted as 3D tensors,
// force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
@@ -218,9 +197,9 @@ void ClGemmMatrixMultiplyReshapedOnlyRhsKernel::configure(const CLCompileContext
ElementsProcessed num_elements_processed{};
// Configure kernel window
- auto win_config = validate_and_configure_window(src0, src1, src2, dst, lhs_info, rhs_info, gemm_info, num_elements_processed);
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- ICLKernel::configure_internal(win_config.second);
+ Window win = validate_and_configure_window(src0->clone().get(), src1->clone().get(), (src2 != nullptr) ? src2->clone().get() : nullptr, dst->clone().get(), lhs_info, rhs_info, gemm_info,
+ num_elements_processed);
+ ICLKernel::configure_internal(win);
// If _reinterpret_input_as_3d = reinterpret_output_as_3d = true,
// we will dispatch a batched-GEMM to reduce the complexity of the address calculation within the OpenCL kernel.
@@ -314,18 +293,7 @@ Status ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(const ITensorInfo *sr
const GEMMLHSMatrixInfo &lhs_info,
const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
{
- ElementsProcessed num_elements_processed{};
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src0->clone().get(),
- src1->clone().get(),
- src2 != nullptr ? src2->clone().get() : nullptr,
- dst->clone().get(),
- lhs_info,
- rhs_info,
- gemm_info,
- num_elements_processed)
- .first);
-
return Status{};
}