aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp')
-rw-r--r-- src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp 55
1 files changed, 46 insertions, 9 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
index e389ce5b0c..7ad3d55fe0 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp
@@ -33,6 +33,8 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/AccessWindowStatic.h"
+#include "src/core/CL/CLUtils.h"
+#include "src/core/experimental/PostOp.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/core/utils/helpers/float_ops.h"
@@ -49,6 +51,17 @@ namespace
{
using ElementsProcessed = Steps;
+const auto post_op_utils = experimental::PostOpCLKernelUtils( // File-local table of the post-op sequences this kernel accepts; queried by validate_arguments() and configure() in this file.
+{
+ // PostOp sequence -> {Kernel Postfix, PostOp Slots}. Each entry pairs an ordered list of PostOpType values with the kernel-name postfix and the slot list used for that sequence; what a slot number means is defined by PostOpCLKernelUtils (not visible here).
+ { {}, { "", {} } }, // no post ops: empty postfix, no slots
+ { { experimental::PostOpType::Activation }, { "", { 1 } } }, // lone activation: empty postfix, slot 1
+ { { experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 2 } } }, // lone element-wise add: slot 2
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 1, 2 } } }, // activation then add: slots 1, 2
+ { { experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } }, // add then activation: slots 2, 3
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } } // activation, add, activation: slots 1, 2, 3
+});
+
Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
const GEMMRHSMatrixInfo &rhs_info,
const GEMMKernelInfo &gemm_info)
@@ -68,6 +81,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
"Bias addition only supported with broadcast mode in case the input or dst has to be reinterpreted as 3D");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.export_to_cl_image, "Export to CLImage not supported for GEMM native");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!post_op_utils.is_post_op_sequence_supported(gemm_info.post_ops), "The sequence of Post Ops is not supported");
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;
@@ -110,6 +124,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!post_op_utils.are_post_op_shapes_compliant(dst, gemm_info.post_ops), "The Post Op shapes are not compliant");
}
return Status{};
@@ -170,16 +185,17 @@ void ClGemmMatrixMultiplyNativeKernel::configure(const CLCompileContext &compile
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
-
// dst tensor auto initialization if not yet initialized
auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
+
auto padding_info = get_padding_info({ src0, src1, src2, dst });
_reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d;
_reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
_use_dummy_work_items = preferred_dummy_work_items_support(CLKernelLibrary::get().get_device());
_add_bias = src2 != nullptr;
+ _num_post_op_args = gemm_info.post_ops.total_num_arguments();
// In case both input and dst have to be reinterpreted as 3D tensors,
// force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
@@ -237,11 +253,20 @@ void ClGemmMatrixMultiplyNativeKernel::configure(const CLCompileContext &compile
build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
build_opts.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(partial_store_m0));
build_opts.add_option("-DPARTIAL_STORE_N0=" + support::cpp11::to_string(partial_store_n0));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
+ // If post_ops are used, then we disable the use of gemm_info.activation_info
+ if(gemm_info.post_ops.size() > 0)
+ {
+ post_op_utils.set_post_ops_cl_build_options(build_opts, gemm_info.post_ops);
+ }
+ else
+ {
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
+ }
std::string kernel_name("gemm_mm_native");
+ post_op_utils.set_post_ops_cl_kernel_name(kernel_name, gemm_info.post_ops);
// Create kernel
_kernel = create_kernel(compile_context, kernel_name, build_opts.options());
@@ -323,11 +348,11 @@ void ClGemmMatrixMultiplyNativeKernel::run_op(ITensorPack &tensors, const Window
unsigned int idx0;
if(_add_bias)
{
- idx0 = 4 * num_arguments_per_2D_tensor() + 4;
+ idx0 = (4 + _num_post_op_args) * num_arguments_per_2D_tensor() + (4 + _num_post_op_args);
}
else
{
- idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+ idx0 = (3 + _num_post_op_args) * num_arguments_per_2D_tensor() + (3 + _num_post_op_args);
}
const unsigned int total_cross_plane_pad = src0->info()->padding().top + src0->info()->padding().bottom;
_kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
@@ -339,11 +364,11 @@ void ClGemmMatrixMultiplyNativeKernel::run_op(ITensorPack &tensors, const Window
unsigned int idx0;
if(_add_bias)
{
- idx0 = 4 * num_arguments_per_2D_tensor() + 4 + (_reinterpret_input_as_3d ? 1 : 0);
+ idx0 = (4 + _num_post_op_args) * num_arguments_per_2D_tensor() + 4 + (_reinterpret_input_as_3d ? 1 : 0) + _num_post_op_args;
}
else
{
- idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
+ idx0 = (3 + _num_post_op_args) * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0) + _num_post_op_args;
}
const unsigned int total_cross_plane_pad = dst->info()->padding().top + dst->info()->padding().bottom;
_kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
@@ -367,6 +392,12 @@ void ClGemmMatrixMultiplyNativeKernel::run_op(ITensorPack &tensors, const Window
add_2D_tensor_argument(idx, src2, slice);
}
add_2D_tensor_argument(idx, dst, slice);
+ // post op argument buffers
+ for(size_t i = 0; i < _num_post_op_args; ++i)
+ {
+ const auto post_op_arg = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(experimental::get_post_op_arg_type(i)));
+ add_2D_tensor_argument(idx, post_op_arg, slice);
+ }
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(src0->info()->strides_in_bytes()[2]));
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(src1->info()->strides_in_bytes()[2]));
if(_add_bias)
@@ -374,6 +405,12 @@ void ClGemmMatrixMultiplyNativeKernel::run_op(ITensorPack &tensors, const Window
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(src2->info()->strides_in_bytes()[2]));
}
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(dst->info()->strides_in_bytes()[2]));
+ // post op argument stride_z
+ for(size_t i = 0; i < _num_post_op_args; ++i)
+ {
+ const auto post_op_arg = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(experimental::get_post_op_arg_type(i)));
+ _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(post_op_arg->info()->strides_in_bytes()[2]));
+ }
enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items);
}
while(window.slide_window_slice_3D(slice));