Diffstat (limited to 'src/gpu/cl/operators/ClGemm.cpp')
-rw-r--r--  src/gpu/cl/operators/ClGemm.cpp  134
1 file changed, 134 insertions(+), 0 deletions(-)
diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp
index 88f6b79b56..4db39a635d 100644
--- a/src/gpu/cl/operators/ClGemm.cpp
+++ b/src/gpu/cl/operators/ClGemm.cpp
@@ -191,6 +191,7 @@ ClGemm::ClGemm()
_mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
_mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
_mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
+ _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()),
_tmp_a(),
_tmp_b(),
_reshape_b_only_on_first_run(false),
@@ -324,6 +325,53 @@ void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context
_aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
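The workspace request above keeps the reshaped RHS persistent only when B is reshaped on the first run alone (constant weights); otherwise it is a per-run temporary. A hypothetical standalone sketch of that lifetime choice (types simplified, not ACL code):

```cpp
// Hypothetical sketch of the RhsReshape workspace lifetime decision.
#include <cstddef>

enum class MemoryLifetime { Temporary, Persistent };

struct WorkspaceRequest
{
    MemoryLifetime lifetime;
    std::size_t    size;
};

// If B is only reshaped on the first run (constant weights), the reshaped
// copy must outlive a single run() call, so it is kept persistent.
WorkspaceRequest make_rhs_reshape_request(bool reshape_b_only_on_first_run, std::size_t total_size)
{
    return { reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, total_size };
}

int main()
{
    const WorkspaceRequest req = make_rhs_reshape_request(true, 4096);
    return req.lifetime == MemoryLifetime::Persistent ? 0 : 1;
}
```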
+void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
+ const GEMMInfo &gemm_info)
+{
+ DataType data_type = a->data_type();
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+ const unsigned int n = b->dimension(0);
+ const unsigned int k = a->dimension(0);
+ const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
+ const GPUTarget gpu_target = CLScheduler::get().target();
+ bool broadcast_bias = gemm_info.broadcast_bias();
+
+ GEMMKernelInfo kernel_info;
+ kernel_info.m = m;
+ kernel_info.n = n;
+ kernel_info.k = k;
+ kernel_info.depth_output_gemm3d = depth_output_gemm3d;
+ kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
+ kernel_info.broadcast_bias = broadcast_bias;
+ kernel_info.activation_info = gemm_info.activation_info();
+ kernel_info.post_ops = gemm_info.post_ops();
+
+ // Set the target for the kernels
+ _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);
+
+ GEMMLHSMatrixInfo lhs_info{};
+ GEMMRHSMatrixInfo rhs_info{};
+
+ // Pick up the GEMM configuration
+ auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+ lhs_info = gemm_config.lhs_info;
+ rhs_info = gemm_config.rhs_info;
+ // Force H0 to 4 in order to use the MMUL extension
+ rhs_info.h0 = 4;
+
+ // Reshape Rhs matrix
+ _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
+
+ // Configure matrix multiply kernel with no y padding support
+ kernel_info.has_pad_y = false;
+ _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+
+ // Request memory for RHS reshape matrix
+ _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
+}
+
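The new configure path derives the GEMM problem sizes from the tensor dimensions, with dimension 0 as the innermost axis; when the input is reinterpreted as 3D, dimensions 1 and 2 collapse into m and the batch size comes from dimension 3. A minimal standalone sketch of that arithmetic (example shapes only, not ACL code):

```cpp
// Standalone illustration of the m/n/k/batch derivation used above.
#include <array>
#include <cstdio>

int main()
{
    // Example LHS/RHS shapes, innermost dimension first (ACL convention).
    std::array<unsigned int, 4> a_shape = {64, 32, 8, 2};
    std::array<unsigned int, 4> b_shape = {48, 64, 8, 1};
    const bool reinterpret_input_as_3d = true;

    // Same arithmetic as configure_reshaped_only_rhs_mmul():
    const unsigned int m          = reinterpret_input_as_3d ? (a_shape[1] * a_shape[2]) : a_shape[1];
    const unsigned int n          = b_shape[0];
    const unsigned int k          = a_shape[0];
    const unsigned int batch_size = reinterpret_input_as_3d ? a_shape[3] : a_shape[2];

    std::printf("m=%u n=%u k=%u batch=%u\n", m, n, k, batch_size);
    return 0;
}
```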
Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
@@ -458,6 +506,54 @@ Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf
return Status{};
}
+Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+ ARM_COMPUTE_UNUSED(alpha);
+ ARM_COMPUTE_UNUSED(output);
+ TensorInfo tmp_b_info{};
+
+ // Get the GPU target
+ const GPUTarget gpu_target = CLScheduler::get().target();
+ const DataType data_type = a->data_type();
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+ const unsigned int n = b->dimension(0);
+ const unsigned int k = a->dimension(0);
+ const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
+ const bool broadcast_bias = gemm_info.broadcast_bias();
+
+ GEMMKernelInfo kernel_info;
+ kernel_info.m = m;
+ kernel_info.n = n;
+ kernel_info.k = k;
+ kernel_info.depth_output_gemm3d = depth_output_gemm3d;
+ kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
+ kernel_info.broadcast_bias = broadcast_bias;
+ kernel_info.activation_info = gemm_info.activation_info();
+ kernel_info.post_ops = gemm_info.post_ops();
+
+ GEMMLHSMatrixInfo lhs_info;
+ GEMMRHSMatrixInfo rhs_info;
+
+ // Pick up the GEMM configuration
+ // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
+ const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+ lhs_info = gemm_config.lhs_info;
+ rhs_info = gemm_config.rhs_info;
+ // Force H0 to 4 in order to use the MMUL extension
+ rhs_info.h0 = 4;
+
+ auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+ ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));
+
+ // Validate matrix multiply
+ kernel_info.has_pad_y = false;
+ ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
+
+ return Status{};
+}
+
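These checks are normally reached through the public runtime wrapper rather than by calling opencl::ClGemm directly. A hedged usage sketch of the validate-before-configure pattern with CLGEMM; the shapes, the F16 data type, and whether the heuristics actually pick the RESHAPED_ONLY_RHS_MMUL kernel on a given GPU target are illustrative assumptions:

```cpp
// Hedged usage sketch: validate, then configure, then run a GEMM through the
// public CLGEMM function, which forwards to the opencl::ClGemm operator.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLGEMM.h"

using namespace arm_compute;

int main()
{
    // Initialise the CL scheduler (context, queue, default tuner).
    CLScheduler::get().default_init();

    // A: M x K, B: K x N, Dst: M x N. TensorShape lists the innermost
    // dimension first, matching k = a->dimension(0) above.
    CLTensor a, b, dst;
    a.allocator()->init(TensorInfo(TensorShape(64U, 32U), 1, DataType::F16));
    b.allocator()->init(TensorInfo(TensorShape(48U, 64U), 1, DataType::F16));
    dst.allocator()->init(TensorInfo(TensorShape(48U, 32U), 1, DataType::F16));

    const GEMMInfo gemm_info{}; // defaults; the kernel type is chosen by heuristics

    // Static validation runs the same per-kernel-type checks as
    // ClGemm::validate_reshaped_only_rhs_mmul() when that kernel is selected.
    const Status status = CLGEMM::validate(a.info(), b.info(), nullptr, dst.info(), 1.0f, 0.0f, gemm_info);
    if(status.error_code() != ErrorCode::OK)
    {
        return 1;
    }

    CLGEMM gemm;
    gemm.configure(&a, &b, nullptr, &dst, 1.0f, 0.0f, gemm_info);

    a.allocator()->allocate();
    b.allocator()->allocate();
    dst.allocator()->allocate();

    gemm.run();
    CLScheduler::get().sync();
    return 0;
}
```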
void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
@@ -501,6 +597,11 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a,
configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
break;
}
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+ {
+ configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("GEMMType not supported");
@@ -545,6 +646,11 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
break;
}
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info));
+ break;
+ }
default:
{
ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
@@ -627,6 +733,34 @@ void ClGemm::run(ITensorPack &tensors)
}
break;
}
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+ {
+ if(!_reshape_b_only_on_first_run)
+ {
+ // Run transpose kernel
+ ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
+ CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
+ }
+ // In case of RESHAPED_ONLY_RHS_MMUL, we need to check the padding requirement
+ // Check if the lhs or dst tensors have padding
+ const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
+ const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
+ bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
+
+ // Copy original tensor pack and overwrite rhs with reshaped counterpart
+ ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
+ gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
+
+ if(has_pad_y)
+ {
+ ARM_COMPUTE_ERROR_ON(has_pad_y);
+ }
+ else
+ {
+ CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true);
+ }
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("GEMMType not supported");
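The has_pad_y branch in the MMUL case exists because the kernel is configured with has_pad_y = false, so a padded LHS or destination tensor hits the error path instead of being enqueued. A hypothetical standalone sketch of that cross-plane padding check (not ACL code):

```cpp
// Standalone illustration of the cross-plane (y) padding check used above.
#include <cstdio>

struct PaddingSize
{
    unsigned int top{0};
    unsigned int bottom{0};
};

// Returns true when either the LHS or the destination tensor carries padding
// along the y dimension, which the MMUL kernel does not support.
bool needs_pad_y(const PaddingSize &lhs_pad, const PaddingSize &dst_pad)
{
    const unsigned int cross_plane_pad_lhs = lhs_pad.top + lhs_pad.bottom;
    const unsigned int cross_plane_pad_dst = dst_pad.top + dst_pad.bottom;
    return (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
}

int main()
{
    std::printf("%d\n", needs_pad_y({0, 0}, {1, 0})); // prints 1: padded destination
    return 0;
}
```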