diff options
Diffstat (limited to 'src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp')
-rw-r--r-- | src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp | 271 |
1 files changed, 137 insertions, 134 deletions
diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp index 390bb97665..4270165ab4 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp @@ -25,7 +25,8 @@ #include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/CLKernelLibrary.h" -#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h" + +#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" #include <map> #include <utility> @@ -34,8 +35,7 @@ namespace arm_compute { namespace cl_gemm { -CLGEMMDefaultTypeBifrost::CLGEMMDefaultTypeBifrost(GPUTarget gpu) - : ICLGEMMKernelSelection(gpu) +CLGEMMDefaultTypeBifrost::CLGEMMDefaultTypeBifrost(GPUTarget gpu) : ICLGEMMKernelSelection(gpu) { } @@ -44,131 +44,133 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::select_kernel(const CLGEMMKernelSelec // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMDefaultTypeBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMDefaultTypeBifrost::*)( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Default configurations for Bifrost architectures - static std::map<DataType, FunctionExecutorPtr> gemm_default_configs = - { - { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 }, - { DataType::F16, &CLGEMMDefaultTypeBifrost::default_f16 }, - { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 } - }; + static std::map<DataType, FunctionExecutorPtr> gemm_default_configs = { + {DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32}, + {DataType::F16, &CLGEMMDefaultTypeBifrost::default_f16}, + {DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8}}; // Mali-G71 configurations - static std::map<DataType, FunctionExecutorPtr> gemm_g71_configs = - { - { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 }, - { DataType::F16, &CLGEMMDefaultTypeBifrost::g71_f16 }, - { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 } - }; + static std::map<DataType, FunctionExecutorPtr> gemm_g71_configs = { + {DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32}, + {DataType::F16, &CLGEMMDefaultTypeBifrost::g71_f16}, + {DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8}}; // Mali-G52 configurations - static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs = - { - { DataType::F32, &CLGEMMDefaultTypeBifrost::g52_f32 }, - { DataType::F16, &CLGEMMDefaultTypeBifrost::g52_f16 }, - { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 } - }; + static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs = { + {DataType::F32, &CLGEMMDefaultTypeBifrost::g52_f32}, + {DataType::F16, &CLGEMMDefaultTypeBifrost::g52_f16}, + {DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8}}; // Mali-G76 configurations - static std::map<DataType, FunctionExecutorPtr> gemm_g76_configs = - { - { DataType::F32, &CLGEMMDefaultTypeBifrost::g76_f32 }, - { DataType::F16, &CLGEMMDefaultTypeBifrost::g76_f16 }, - { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 }, - { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 } - }; + static std::map<DataType, FunctionExecutorPtr> gemm_g76_configs = { + {DataType::F32, &CLGEMMDefaultTypeBifrost::g76_f32}, + {DataType::F16, &CLGEMMDefaultTypeBifrost::g76_f16}, + {DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8}, + {DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8}}; const DataType data_type = params.data_type; - switch(_target) + switch (_target) { case GPUTarget::G71: - if(gemm_g71_configs.find(data_type) != gemm_g71_configs.end()) + if (gemm_g71_configs.find(data_type) != gemm_g71_configs.end()) { - return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, + params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); case GPUTarget::G76: - if(gemm_g76_configs.find(data_type) != gemm_g76_configs.end()) + if (gemm_g76_configs.find(data_type) != gemm_g76_configs.end()) { - return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, + params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); case GPUTarget::G52: - if(gemm_g52_configs.find(data_type) != gemm_g52_configs.end()) + if (gemm_g52_configs.find(data_type) != gemm_g52_configs.end()) { - return (this->*gemm_g52_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + return (this->*gemm_g52_configs[data_type])(params.m, params.n, params.k, params.b, + params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); default: - if(gemm_default_configs.find(data_type) != gemm_default_configs.end()) + if (gemm_default_configs.find(data_type) != gemm_default_configs.end()) { - return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, + params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(b); - CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE_V1; + CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE; - if(is_rhs_constant) + if (is_rhs_constant) { - if((m > 1) && (n < 16)) + if ((m > 1) && (n < 16)) { - gemm_type = CLGEMMKernelType::RESHAPED_V1; + gemm_type = CLGEMMKernelType::RESHAPED; } - else if(m == 1) + else if (m == 1) { gemm_type = CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if((k > 256) && (m > 4)) + if ((k > 256) && (m > 4)) { constexpr float alpha = 3.2f; constexpr float fact0 = 1.51f; constexpr float fact1 = 1.66f; constexpr float ops = 12.0f; const float scale = k > 1024 ? 1.07f : 1.0f; - gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? CLGEMMKernelType::RESHAPED_V1 : CLGEMMKernelType::NATIVE_V1; + gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) + ? CLGEMMKernelType::RESHAPED + : CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - gemm_type = CLGEMMKernelType::NATIVE_V1; + gemm_type = CLGEMMKernelType::RESHAPED_ONLY_RHS; } } const auto workload = static_cast<float>((m * n) / 20.0f); - gemm_type = ((workload > 1600.0f) && (gemm_type == CLGEMMKernelType::RESHAPED_V1)) ? CLGEMMKernelType::RESHAPED : gemm_type; + gemm_type = ((workload > 1600.0f) && (gemm_type == CLGEMMKernelType::RESHAPED)) ? CLGEMMKernelType::RESHAPED + : gemm_type; } return gemm_type; } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(n, k, b); - if(is_rhs_constant) + if (is_rhs_constant) { - if(m == 1) + if (m == 1) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -179,15 +181,16 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f16(unsigned int m, unsigned } else { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_q8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(m, n, k, b); - if(is_rhs_constant) + if (is_rhs_constant) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -197,21 +200,22 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_q8(unsigned int m, unsigned i } } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType +CLGEMMDefaultTypeBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(b); - if(!is_rhs_constant) + if (!is_rhs_constant) { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } - if(m == 1) + if (m == 1) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } - if(k <= 496) + if (k <= 496) { - if(n <= 544) + if (n <= 544) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -222,17 +226,17 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f32(unsigned int m, unsigned int } else { - if(k <= 588) + if (k <= 588) { - if(k <= 552) + if (k <= 552) { - if(m <= 148) + if (m <= 148) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(m <= 278) + if (m <= 278) { return CLGEMMKernelType::RESHAPED; } @@ -254,16 +258,17 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f32(unsigned int m, unsigned int } } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType +CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(b); - if(!is_rhs_constant) + if (!is_rhs_constant) { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } - if(m == 1) + if (m == 1) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -273,13 +278,13 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int const float r_nk = static_cast<float>(n) / static_cast<float>(k); const float r_mnk = static_cast<float>(m) / (static_cast<float>(n) * static_cast<float>(k)); - if(r_mn <= 1.5469f) + if (r_mn <= 1.5469f) { - if(r_mk <= 0.8766f) + if (r_mk <= 0.8766f) { - if(r_mk <= 0.0211f) + if (r_mk <= 0.0211f) { - if(r_mnk <= 77.5833f) + if (r_mnk <= 77.5833f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -290,7 +295,7 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int } else { - if(r_nk <= 0.0832f) + if (r_nk <= 0.0832f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -302,11 +307,11 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int } else { - if(r_mnk <= 193.0000f) + if (r_mnk <= 193.0000f) { - if(r_mn <= 0.9948f) + if (r_mn <= 0.9948f) { - if(r_mk <= 2.5453f) + if (r_mk <= 2.5453f) { return CLGEMMKernelType::RESHAPED; } @@ -328,17 +333,17 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int } else { - if(r_mn <= 17.7370f) + if (r_mn <= 17.7370f) { - if(r_mnk <= 1391.2875f) + if (r_mnk <= 1391.2875f) { - if(r_mk <= 2.9724f) + if (r_mk <= 2.9724f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(r_mnk <= 470.0000f) + if (r_mnk <= 470.0000f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -350,9 +355,9 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int } else { - if(r_nk <= 0.1381f) + if (r_nk <= 0.1381f) { - if(r_mnk <= 9040.5000f) + if (r_mnk <= 9040.5000f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -363,7 +368,7 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int } else { - if(r_mn <= 5.6790f) + if (r_mn <= 5.6790f) { return CLGEMMKernelType::RESHAPED; } @@ -381,16 +386,17 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int } } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType +CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(b); - if(!is_rhs_constant) + if (!is_rhs_constant) { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } - if(m == 1) + if (m == 1) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -398,21 +404,21 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int const float r_mn = static_cast<float>(m) / static_cast<float>(n); const float r_nk = static_cast<float>(n) / static_cast<float>(k); - if(k <= 212) + if (k <= 212) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(r_nk <= 0.4990234375f) + if (r_nk <= 0.4990234375f) { - if(k <= 1392) + if (k <= 1392) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(m <= 325) + if (m <= 325) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -424,13 +430,13 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int } else { - if(k <= 471) + if (k <= 471) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(r_mn <= 0.04475911520421505f) + if (r_mn <= 0.04475911520421505f) { return CLGEMMKernelType::RESHAPED; } @@ -443,37 +449,38 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int } } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType +CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - if(!is_rhs_constant) + if (!is_rhs_constant) { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } - if(m == 1) + if (m == 1) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } - if(n <= 127.0000f) + if (n <= 127.0000f) { - if(n <= 63.5000f) + if (n <= 63.5000f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(m <= 3616.0000f) + if (m <= 3616.0000f) { - if(b <= 18.5000f) + if (b <= 18.5000f) { - if(m <= 2970.5000f) + if (m <= 2970.5000f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(k <= 104.0000f) + if (k <= 104.0000f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -496,19 +503,19 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int } else { - if(m <= 12.5000f) + if (m <= 12.5000f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - if(k <= 104.0000f) + if (k <= 104.0000f) { - if(b <= 18.5000f) + if (b <= 18.5000f) { - if(m <= 490.0000f) + if (m <= 490.0000f) { - if(n <= 272.0000f) + if (n <= 272.0000f) { return CLGEMMKernelType::RESHAPED_ONLY_RHS; } @@ -529,11 +536,11 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int } else { - if(m <= 226.0000f) + if (m <= 226.0000f) { - if(n <= 140.0000f) + if (n <= 140.0000f) { - if(m <= 179.5000f) + if (m <= 179.5000f) { return CLGEMMKernelType::RESHAPED; } @@ -556,22 +563,18 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int } } -CLGEMMKernelType CLGEMMDefaultTypeBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +CLGEMMKernelType +CLGEMMDefaultTypeBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(b); + ARM_COMPUTE_UNUSED(n); + ARM_COMPUTE_UNUSED(k); - if(is_rhs_constant) + if (is_rhs_constant) { - if(m == 1) + if (m == 1) { - if(n > k) - { - return CLGEMMKernelType::NATIVE_V1; - } - else - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } + return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { @@ -580,7 +583,7 @@ CLGEMMKernelType CLGEMMDefaultTypeBifrost::g71_f16(unsigned int m, unsigned int } else { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } } } // namespace cl_gemm |