diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 56 |
1 files changed, 36 insertions, 20 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index e78395f1de..762b00177c 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -65,37 +65,53 @@ CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsign { GEMMType gemm_type = GEMMType::RESHAPED_V1; - if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76)) + if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, + GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, + GPUTarget::G76, GPUTarget::G77)) { - if((m > 1) && (n < 16)) + if(data_type == DataType::F32) { - gemm_type = GEMMType::RESHAPED_V1; - } - else if((m == 1) && (data_type == DataType::F32)) - { - gemm_type = GEMMType::RESHAPED_ONLY_RHS; + if((m > 1) && (n < 16)) + { + gemm_type = GEMMType::RESHAPED_V1; + } + else if(m == 1) + { + gemm_type = GEMMType::RESHAPED_ONLY_RHS; + } + else + { + // COMPMID-852 + if((k > 256) && (m > 4) && reshape_b_only_on_first_run) + { + constexpr float alpha = 3.2f; + constexpr float fact0 = 1.51f; + constexpr float fact1 = 1.66f; + constexpr float ops = 12.0f; + const float scale = k > 1024 ? 1.07f : 1.0f; + gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE; + } + else + { + gemm_type = GEMMType::NATIVE; + } + } + + const auto workload = static_cast<float>((m * n) / 20.0f); + + gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type; } else { - // COMPMID-852 - if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run) + if((m == 1) || (!reshape_b_only_on_first_run)) { - constexpr float alpha = 3.2f; - constexpr float fact0 = 1.51f; - constexpr float fact1 = 1.66f; - constexpr float ops = 12.0f; - const float scale = k > 1024 ? 1.07f : 1.0f; - gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE; + gemm_type = GEMMType::NATIVE; } else { - gemm_type = GEMMType::NATIVE; + gemm_type = GEMMType::RESHAPED_V2; } } - - const auto workload = static_cast<float>((m * n) / 20.0f); - - gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type; } else { |