aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp')
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp46
1 files changed, 41 insertions, 5 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
index 7ad3d55fe0..c3efc24fa9 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
@@ -134,12 +134,15 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src0, ITens
const GEMMRHSMatrixInfo &rhs_info,
const GEMMKernelInfo &gemm_info, ElementsProcessed &num_elements_processed)
{
- ARM_COMPUTE_UNUSED(src0, src1, src2);
unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d;
bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
+ Window win{};
+ Window win_out{};
+ bool window_changed = false;
+
// In case both input and dst have to be reinterpreted as 3D tensors,
// force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
if(reinterpret_input_as_3d == reinterpret_output_as_3d)
@@ -147,6 +150,9 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src0, ITens
reinterpret_output_as_3d = false;
}
+ // dst tensor auto initialization if not yet initialized
+ auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)));
+
TensorInfo tmp_info(*dst);
if(reinterpret_output_as_3d)
@@ -162,14 +168,44 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src0, ITens
num_elems_processed_per_iteration_x = rhs_info.n0;
num_elems_processed_per_iteration_y = lhs_info.m0;
- Window win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win_out = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+ AccessWindowStatic src0_access(src0, 0, 0,
+ src0->dimension(0),
+ src0->dimension(1));
+ AccessWindowStatic src1_access(src1, 0, 0,
+ ceil_to_multiple(src1->dimension(0), num_elems_processed_per_iteration_x),
+ src1->dimension(1));
+ AccessWindowStatic dst_access(dst, 0, 0,
+ dst->dimension(0),
+ dst->dimension(1));
+
+ if(src2 != nullptr)
+ {
+ const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x;
+
+ AccessWindowStatic src2_access(src2, 0, 0,
+ ceil_to_multiple(src2->dimension(0), bias_processed_per_iteration_x),
+ src2->dimension(1));
+
+ window_changed = update_window_and_padding(win, src0_access, src1_access, src2_access) || // window used by the execute_window_loop
+ update_window_and_padding(win_out, dst_access); // window used to update the padding requirements of dst tensor
+ }
+ else
+ {
+ window_changed = update_window_and_padding(win, src0_access, src1_access) || // window used by the execute_window_loop
+ update_window_and_padding(win_out, dst_access); // window used to update the padding requirements of dst tensor
+ }
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
+ Window collapsed = win;
const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(dst->num_dimensions()), 2u);
- Window collapsed = win.collapse(win, dimension_to_collapse);
+ collapsed = win.collapse(win, dimension_to_collapse);
- return std::make_pair(Status{}, collapsed);
+ Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, collapsed);
}
} // namespace
@@ -190,7 +226,7 @@ void ClGemmMatrixMultiplyNativeKernel::configure(const CLCompileContext &compile
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
- auto padding_info = get_padding_info({ src0, src1, src2, dst });
+ auto padding_info = get_padding_info({ src0, dst });
_reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d;
_reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
_use_dummy_work_items = preferred_dummy_work_items_support(CLKernelLibrary::get().get_device());