aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp')
-rw-r--r--src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp48
1 files changed, 32 insertions, 16 deletions
diff --git a/src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp b/src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp
index bf4b664b6e..eea2a169a3 100644
--- a/src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp
@@ -31,6 +31,7 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/StringUtils.h"
+
#include "src/core/AccessWindowStatic.h"
#include "src/core/CL/CLValidate.h"
#include "src/core/helpers/AutoConfiguration.h"
@@ -46,13 +47,17 @@ namespace kernels
{
namespace
{
-Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d)
+Status validate_arguments(const ITensorInfo *src,
+ const ITensorInfo *dst,
+ const GEMMLHSMatrixInfo &lhs_info,
+ bool reinterpret_input_as_3d)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 == 0);
ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.k0 == 0);
ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.v0 == 0);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(((lhs_info.k0 & (lhs_info.k0 - 1)) && lhs_info.k0 != 3), "Only 2,3,4,8,16 are supported for k0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(((lhs_info.k0 & (lhs_info.k0 - 1)) && lhs_info.k0 != 3),
+ "Only 2,3,4,8,16 are supported for k0");
ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.k0 > 16);
ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 2 || lhs_info.m0 > 8);
ARM_COMPUTE_RETURN_ERROR_ON((lhs_info.m0 > 4 && lhs_info.m0 < 8) && lhs_info.transpose);
@@ -60,10 +65,11 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src);
ARM_COMPUTE_RETURN_ERROR_ON(src->data_type() == DataType::UNKNOWN);
- if(dst->total_size() != 0)
+ if (dst->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(),
- misc::shape_calculator::compute_lhs_reshaped_shape(*src, lhs_info, reinterpret_input_as_3d));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(
+ dst->tensor_shape(),
+ misc::shape_calculator::compute_lhs_reshaped_shape(*src, lhs_info, reinterpret_input_as_3d));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
}
@@ -71,14 +77,15 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const
return Status{};
}
-Window configure_window(ITensorInfo *src, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d)
+Window
+configure_window(ITensorInfo *src, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d)
{
const unsigned int num_elems_processed_per_iteration_x = lhs_info.k0;
const unsigned int num_elems_processed_per_iteration_y = lhs_info.m0;
TensorInfo tmp_info(*src);
- if(reinterpret_input_as_3d)
+ if (reinterpret_input_as_3d)
{
// Since the src tensor has to be reinterpreted as 3D and the execute window is based on a 2D interleave,
// the window needs to be constructed on the 2D collapsed version of the tensor
@@ -88,10 +95,12 @@ Window configure_window(ITensorInfo *src, ITensorInfo *dst, const GEMMLHSMatrixI
}
// dst auto inizialitation if not yet initialized
- auto_init_if_empty(*dst, src->clone()->set_tensor_shape(misc::shape_calculator::compute_lhs_reshaped_shape(*src, lhs_info, reinterpret_input_as_3d)));
+ auto_init_if_empty(*dst, src->clone()->set_tensor_shape(misc::shape_calculator::compute_lhs_reshaped_shape(
+ *src, lhs_info, reinterpret_input_as_3d)));
// Configure window
- Window win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ Window win =
+ calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
@@ -106,14 +115,18 @@ ClGemmReshapeLhsMatrixKernel::ClGemmReshapeLhsMatrixKernel()
_type = CLKernelType::ELEMENTWISE;
}
-void ClGemmReshapeLhsMatrixKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d)
+void ClGemmReshapeLhsMatrixKernel::configure(const CLCompileContext &compile_context,
+ ITensorInfo *src,
+ ITensorInfo *dst,
+ const GEMMLHSMatrixInfo &lhs_info,
+ bool reinterpret_input_as_3d)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
// Perform validate step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst, lhs_info, reinterpret_input_as_3d));
- auto padding_info = get_padding_info({ src });
+ auto padding_info = get_padding_info({src});
const unsigned int src_w = src->dimension(0);
const unsigned int m = reinterpret_input_as_3d ? src->dimension(1) * src->dimension(2) : src->dimension(1);
@@ -168,7 +181,10 @@ void ClGemmReshapeLhsMatrixKernel::configure(const CLCompileContext &compile_con
ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info));
}
-Status ClGemmReshapeLhsMatrixKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, bool reinterpret_input_as_3d)
+Status ClGemmReshapeLhsMatrixKernel::validate(const ITensorInfo *src,
+ const ITensorInfo *dst,
+ const GEMMLHSMatrixInfo &lhs_info,
+ bool reinterpret_input_as_3d)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, dst, lhs_info, reinterpret_input_as_3d));
return Status{};
@@ -179,8 +195,9 @@ void ClGemmReshapeLhsMatrixKernel::run_op(ITensorPack &tensors, const Window &wi
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
- const auto src = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC));
- auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
+ const auto src =
+ utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC));
+ auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
@@ -192,8 +209,7 @@ void ClGemmReshapeLhsMatrixKernel::run_op(ITensorPack &tensors, const Window &wi
add_3d_tensor_nhw_argument(idx, src);
add_3d_tensor_nhw_argument(idx, dst);
enqueue(queue, *this, slice, lws_hint());
- }
- while(window.slide_window_slice_3D(slice));
+ } while (window.slide_window_slice_3D(slice));
}
} // namespace kernels
} // namespace opencl