diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2021-05-17 13:03:50 +0100 |
---|---|---|
committer | Giorgio Arena <giorgio.arena@arm.com> | 2021-05-20 15:19:39 +0000 |
commit | 4403ed3ed09491686a0b182fa498344b005ca812 (patch) | |
tree | 5a231a71d70a7b3ae2412729d8f6a170b54510f7 /src/runtime/gpu/cl | |
parent | ea8d266515812c4dec936b2153ffd5335873e583 (diff) | |
download | ComputeLibrary-4403ed3ed09491686a0b182fa498344b005ca812.tar.gz |
Add support for dynamic weights in CL FullyConnected layer
Make GEMM use its native version if weights are dynamic. This ensures no reshape gets performed on the weights tensor
Enable dynamic weights tests for the OpenCL backend
Resolve COMPMID-4223
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: Iccc4806701772cede23e24df09c786914d00034c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5652
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'src/runtime/gpu/cl')
-rw-r--r-- | src/runtime/gpu/cl/operators/ClGemm.cpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/src/runtime/gpu/cl/operators/ClGemm.cpp b/src/runtime/gpu/cl/operators/ClGemm.cpp index fcbc6d5fba..a80375447d 100644 --- a/src/runtime/gpu/cl/operators/ClGemm.cpp +++ b/src/runtime/gpu/cl/operators/ClGemm.cpp @@ -78,8 +78,13 @@ inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type) } } //Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type -inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run) +inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights) { + if(!constant_weights) + { + return CLGEMMKernelType::NATIVE_V1; + } + auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run); if(bool(gemm_kernel)) { @@ -564,7 +569,8 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); // Select GEMMType - _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run); + _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run, + gemm_info.constant_weights()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); @@ -613,7 +619,7 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso { CLScheduler::get().target(), a->data_type(), m, n, k, batch_size, }, - gemm_info.reshape_b_only_on_first_run()); + gemm_info.reshape_b_only_on_first_run(), gemm_info.constant_weights()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); |