aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp')
-rw-r--r--src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp42
1 files changed, 39 insertions, 3 deletions
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
index 3a39128c0a..4b28e2badc 100644
--- a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp
@@ -34,6 +34,7 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/CL/CLUtils.h"
#include "src/core/CL/CLValidate.h"
+#include "src/core/experimental/PostOp.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/core/utils/helpers/float_ops.h"
@@ -51,6 +52,16 @@ namespace
{
using ElementsProcessed = Steps;
+const auto post_op_utils = experimental::PostOpCLKernelUtils(
+{
+ // PostOp sequence -> {Kernel Postfix, PostOp Slots}
+ { {}, { "", {} } },
+ { { experimental::PostOpType::Activation }, { "", { 1 } } },
+ { { experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 2 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add }, { "_post_act_eltwise_op_act", { 1, 2 } } },
+ { { experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 2, 3 } } },
+ { { experimental::PostOpType::Activation, experimental::PostOpType::Eltwise_Add, experimental::PostOpType::Activation }, { "_post_act_eltwise_op_act", { 1, 2, 3 } } }
+});
Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
const GEMMRHSMatrixInfo &rhs_info,
const GEMMKernelInfo &gemm_info)
@@ -74,6 +85,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
"Bias addition only supported with broadcast mode in case the input or dst has to be reinterpreted as 3D");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision && (src0->data_type() == DataType::F32), "Mixed precision only supported for F16 data type");
ARM_COMPUTE_RETURN_ON_ERROR(gemm::validate_image2d_support_on_rhs(*src1, rhs_info));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!post_op_utils.is_post_op_sequence_supported(gemm_info.post_ops), "The sequence of Post Ops is not supported");
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;
@@ -117,6 +129,7 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!post_op_utils.are_post_op_shapes_compliant(dst, gemm_info.post_ops), "The Post Op shapes are not compliant");
}
return Status{};
@@ -180,6 +193,7 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi
_add_bias = src2 != nullptr;
_export_to_cl_image = rhs_info.export_to_cl_image;
_k = gemm_info.k;
+ _num_post_op_args = gemm_info.post_ops.total_num_arguments();
// Check if we need to slide the matrix B
const unsigned int num_dimensions_src0 = src0->num_dimensions();
@@ -222,9 +236,6 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi
build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE");
build_opts.add_option_if(lhs_info.transpose, "-DLHS_TRANSPOSE");
build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS");
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
build_opts.add_option_if(enable_mixed_precision, "-DMIXED_PRECISION");
build_opts.add_option_if(rhs_info.export_to_cl_image, "-DOPENCL_IMAGE_SUPPORT");
build_opts.add_option("-DRHS_HEIGHT=" + support::cpp11::to_string(src1->dimension(1)));
@@ -240,11 +251,23 @@ void ClGemmMatrixMultiplyReshapedKernel::configure(const CLCompileContext &compi
build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
build_opts.add_option("-DPARTIAL_STORE_M0=" + support::cpp11::to_string(partial_store_m0));
build_opts.add_option("-DPARTIAL_STORE_N0=" + support::cpp11::to_string(partial_store_n0));
+ // If post_ops are used, then we disable the use of gemm_info.activation_info
+ if(gemm_info.post_ops.size() > 0)
+ {
+ post_op_utils.set_post_ops_cl_build_options(build_opts, gemm_info.post_ops);
+ }
+ else
+ {
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
+ }
std::string kernel_name("gemm_mm_reshaped_");
kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_";
kernel_name += rhs_info.transpose ? "rhs_t" : "rhs_nt";
kernel_name += rhs_info.export_to_cl_image ? "_texture" : "";
+ post_op_utils.set_post_ops_cl_kernel_name(kernel_name, gemm_info.post_ops);
// Create kernel
_kernel = create_kernel(compile_context, kernel_name, build_opts.options());
@@ -360,6 +383,13 @@ void ClGemmMatrixMultiplyReshapedKernel::run_op(ITensorPack &tensors, const Wind
// dst buffer
add_2D_tensor_argument(idx, dst, slice);
+ // post op argument buffers
+ for(size_t i = 0; i < _num_post_op_args; ++i)
+ {
+ const auto post_op_arg = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(experimental::get_post_op_arg_type(i)));
+ add_2D_tensor_argument(idx, post_op_arg, slice);
+ }
+
// K dimension (not used if _export_to_cl_image == true)
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_k));
@@ -378,6 +408,12 @@ void ClGemmMatrixMultiplyReshapedKernel::run_op(ITensorPack &tensors, const Wind
// dst stride_z
_kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(dst->info()->strides_in_bytes()[2]));
+ // post op argument stride_z
+ for(size_t i = 0; i < _num_post_op_args; ++i)
+ {
+ const auto post_op_arg = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(experimental::get_post_op_arg_type(i)));
+ _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(post_op_arg->info()->strides_in_bytes()[2]));
+ }
// Cross-plan padding (if _reinterpret_output_as_3d = true)
if(_reinterpret_output_as_3d)
{