From 579ca84bd8ef5a91eded65c4dc5e0b9f7de8bef1 Mon Sep 17 00:00:00 2001 From: SiCongLi Date: Mon, 18 Oct 2021 09:38:33 +0100 Subject: Add PostOp support to GEMM and CLGEMM operators and functions Part 2 * Implement PostOp interface changes * Remove spaces around "=" in TypePrinter Partially resolves COMPMID-4435 Signed-off-by: SiCongLi Change-Id: If1e2280554030a0f635e73339a2e86987f6dc41b Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6484 Tested-by: Arm Jenkins Reviewed-by: Sheri Zhang Comments-Addressed: Arm Jenkins --- src/gpu/cl/operators/ClGemm.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'src/gpu/cl/operators/ClGemm.cpp') diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp index d2d0f8f91d..e05256ee2f 100644 --- a/src/gpu/cl/operators/ClGemm.cpp +++ b/src/gpu/cl/operators/ClGemm.cpp @@ -38,6 +38,7 @@ #include "arm_compute/runtime/CL/CLScheduler.h" #include "arm_compute/runtime/ITensorAllocator.h" +#include "arm_compute/core/experimental/IPostOp.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/MemoryHelpers.h" #include "src/core/utils/helpers/float_ops.h" @@ -64,7 +65,7 @@ namespace { inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type) { - return kernel_type == CLGEMMKernelType::NATIVE? false : true; + return kernel_type == CLGEMMKernelType::NATIVE ? false : true; } //Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights) @@ -203,6 +204,7 @@ ClGemm::ClGemm() void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) { + ARM_COMPUTE_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel"); DataType data_type = a->data_type(); bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); @@ -252,6 +254,7 @@ void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensor kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; kernel_info.activation_info = gemm_info.activation_info(); + kernel_info.post_ops = gemm_info.post_ops(); // Set the target for the kernels _reshape_lhs_kernel->set_target(gpu_target); @@ -278,6 +281,7 @@ void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensor void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) { + ARM_COMPUTE_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel"); DataType data_type = a->data_type(); bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); @@ -330,6 +334,7 @@ Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const { ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_UNUSED(output); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel"); // Get the GPU target const GPUTarget gpu_target = CLScheduler::get().target(); @@ -386,6 +391,7 @@ Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, con kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; kernel_info.activation_info = gemm_info.activation_info(); + kernel_info.post_ops = gemm_info.post_ops(); GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; @@ -412,6 +418,7 @@ Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf { ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_UNUSED(output); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel"); TensorInfo tmp_b_info{}; @@ -588,8 +595,10 @@ void ClGemm::run(ITensorPack &tensors) ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } }; CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false); } - - ITensorPack gemm_reshaped_pack{ { ACL_SRC_0, lhs_reshaped.get() }, { ACL_SRC_1, rhs_reshaped.get() }, { ACL_SRC_2, src2 }, { ACL_DST, dst } }; + // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts + ITensorPack gemm_reshaped_pack(tensors); + gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get()); + gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get()); if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED) { -- cgit v1.2.1