diff options
Diffstat (limited to 'src/runtime')
-rw-r--r-- | src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp | 122 | ||||
-rw-r--r-- | src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h | 1 |
2 files changed, 122 insertions, 1 deletions
diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
index c77746a044..46d07fffba 100644
--- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
+++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
@@ -72,7 +72,7 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS
     static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs =
     {
         { DataType::F32, &CLGEMMKernelSelectionBifrost::g52_f32 },
-        { DataType::F16, &CLGEMMKernelSelectionBifrost::default_f16 },
+        { DataType::F16, &CLGEMMKernelSelectionBifrost::g52_f16 },
         { DataType::QASYMM8, &CLGEMMKernelSelectionBifrost::default_q8 },
         { DataType::QASYMM8_SIGNED, &CLGEMMKernelSelectionBifrost::default_q8 },
         { DataType::QSYMM8, &CLGEMMKernelSelectionBifrost::default_q8 },
@@ -443,6 +443,82 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f16(unsigned int m, unsigned
     }
 }
 
+// Kernel-type heuristic registered for DataType::F16 in gemm_g52_configs.
+//
+// m, n and k are only used through the ratios computed below (presumably the GEMM
+// dimensions - TODO confirm); b is accepted for signature compatibility and ignored.
+// Returns NATIVE_V1 when the RHS is not constant and RESHAPED_ONLY_RHS when m == 1;
+// otherwise RESHAPED or RESHAPED_ONLY_RHS is chosen by a fixed threshold tree.
+//
+// NOTE(review): r_mnk is m / (n * k), i.e. r_mn / k, so for k >= 1 it is at most r_mn
+// (up to rounding). Several r_mnk thresholds are far above the r_mn bounds guarding
+// them and look unreachable; confirm the feature definition with whoever tuned this.
+CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
+{
+    ARM_COMPUTE_UNUSED(b);
+
+    if(!is_rhs_constant)
+    {
+        return CLGEMMKernelType::NATIVE_V1;
+    }
+
+    const CLGEMMKernelType only_rhs = CLGEMMKernelType::RESHAPED_ONLY_RHS;
+    const CLGEMMKernelType reshaped = CLGEMMKernelType::RESHAPED;
+
+    if(m == 1)
+    {
+        return only_rhs;
+    }
+
+    const float fm = static_cast<float>(m);
+    const float fn = static_cast<float>(n);
+    const float fk = static_cast<float>(k);
+
+    const float r_mn  = fm / fn;
+    const float r_mk  = fm / fk;
+    const float r_nk  = fn / fk;
+    const float r_mnk = fm / (fn * fk);
+
+    // Every test stays in "<=" form, negated where the "greater" side is handled first,
+    // so a NaN ratio (e.g. 0 / 0) still lands on the "greater" side of each threshold.
+    if(!(r_mn <= 22.9200f))
+    {
+        const bool use_reshaped = (r_mn <= 86.7578f) && !(r_mnk <= 11231.6406f);
+        return use_reshaped ? reshaped : only_rhs;
+    }
+
+    if(r_mk <= 0.0157f)
+    {
+        return only_rhs;
+    }
+
+    if(!(r_mnk <= 7809.3750f))
+    {
+        return only_rhs;
+    }
+
+    if(r_mnk <= 101.7937f)
+    {
+        const bool use_reshaped = (r_mn <= 0.4594f) && !(r_mk <= 0.0557f);
+        return use_reshaped ? reshaped : only_rhs;
+    }
+
+    if(r_nk <= 0.4396f)
+    {
+        if(r_mn <= 1.5182f)
+        {
+            return (r_mnk <= 1709.9167f) ? reshaped : only_rhs;
+        }
+        return (r_mnk <= 1330.6000f) ? only_rhs : reshaped;
+    }
+
+    if(r_mn <= 2.5896f)
+    {
+        return reshaped;
+    }
+    return (r_mnk <= 326.6667f) ? only_rhs : reshaped;
+}
+
 CLGEMMKernelType CLGEMMKernelSelectionBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
 {
     ARM_COMPUTE_UNUSED(b);
diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
index fbafc531f5..6831a12aec 100644
--- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
+++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
@@ -47,6 +47,7 @@ private:
     CLGEMMKernelType g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);