aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/operators/ClGemm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/operators/ClGemm.cpp')
-rw-r--r--src/gpu/cl/operators/ClGemm.cpp15
1 files changed, 12 insertions, 3 deletions
diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp
index d2d0f8f91d..e05256ee2f 100644
--- a/src/gpu/cl/operators/ClGemm.cpp
+++ b/src/gpu/cl/operators/ClGemm.cpp
@@ -38,6 +38,7 @@
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/ITensorAllocator.h"
+#include "arm_compute/core/experimental/IPostOp.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/helpers/float_ops.h"
@@ -64,7 +65,7 @@ namespace
{
inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
{
- return kernel_type == CLGEMMKernelType::NATIVE? false : true;
+ return kernel_type == CLGEMMKernelType::NATIVE ? false : true;
}
//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
@@ -203,6 +204,7 @@ ClGemm::ClGemm()
void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
const GEMMInfo &gemm_info)
{
+ ARM_COMPUTE_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
DataType data_type = a->data_type();
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
@@ -252,6 +254,7 @@ void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensor
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = broadcast_bias;
kernel_info.activation_info = gemm_info.activation_info();
+ kernel_info.post_ops = gemm_info.post_ops();
// Set the target for the kernels
_reshape_lhs_kernel->set_target(gpu_target);
@@ -278,6 +281,7 @@ void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensor
void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
const GEMMInfo &gemm_info)
{
+ ARM_COMPUTE_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
DataType data_type = a->data_type();
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
@@ -330,6 +334,7 @@ Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const
{
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
// Get the GPU target
const GPUTarget gpu_target = CLScheduler::get().target();
@@ -386,6 +391,7 @@ Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, con
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = broadcast_bias;
kernel_info.activation_info = gemm_info.activation_info();
+ kernel_info.post_ops = gemm_info.post_ops();
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
@@ -412,6 +418,7 @@ Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf
{
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
TensorInfo tmp_b_info{};
@@ -588,8 +595,10 @@ void ClGemm::run(ITensorPack &tensors)
ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
}
-
- ITensorPack gemm_reshaped_pack{ { ACL_SRC_0, lhs_reshaped.get() }, { ACL_SRC_1, rhs_reshaped.get() }, { ACL_SRC_2, src2 }, { ACL_DST, dst } };
+ // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
+ ITensorPack gemm_reshaped_pack(tensors);
+ gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
+ gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
{