diff options
Diffstat (limited to 'src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp')
-rw-r--r-- | src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp | 54 |
1 files changed, 29 insertions, 25 deletions
diff --git a/src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp b/src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp index 95a4d2bd69..97a1298b0a 100644 --- a/src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp +++ b/src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeValhall.cpp @@ -26,6 +26,7 @@ #include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/GPUTarget.h" + #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" #include <utility> @@ -38,37 +39,38 @@ namespace kernels { namespace gemm { -ClGemmDefaultConfigNativeValhall::ClGemmDefaultConfigNativeValhall(GPUTarget gpu) - : IClGemmKernelConfig(gpu) +ClGemmDefaultConfigNativeValhall::ClGemmDefaultConfigNativeValhall(GPUTarget gpu) : IClGemmKernelConfig(gpu) { } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall::configure( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigNativeValhall::*)(unsigned int m, unsigned int n, unsigned int k, - unsigned int b); + using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ( + ClGemmDefaultConfigNativeValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_default(&ClGemmDefaultConfigNativeValhall::configure_G77_f32, - &ClGemmDefaultConfigNativeValhall::configure_G77_f16, - &ClGemmDefaultConfigNativeValhall::configure_G77_u8); + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_default( + &ClGemmDefaultConfigNativeValhall::configure_G77_f32, &ClGemmDefaultConfigNativeValhall::configure_G77_f16, + &ClGemmDefaultConfigNativeValhall::configure_G77_u8); auto func = configs_default.get_function(data_type); ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM"); return (this->*func)(m, n, k, b); } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> +ClGemmDefaultConfigNativeValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { - if(n < 2048) + if (n < 2048) { return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false); } - else if(n >= 2048 && n < 8192) + else if (n >= 2048 && n < 8192) { return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false); } @@ -83,18 +85,19 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> +ClGemmDefaultConfigNativeValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { - if(n < 2048) + if (n < 2048) { return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 1, false, false, false, false); } - else if(n >= 2048 && n < 8192) + else if (n >= 2048 && n < 8192) { return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 1, false, false, false, false); } @@ -109,20 +112,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> +ClGemmDefaultConfigNativeValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(dot8_supported(CLKernelLibrary::get().get_device())) + if (dot8_supported(CLKernelLibrary::get().get_device())) { - if(m == 1) + if (m == 1) { - if(n < 2048) + if (n < 2048) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 1, false, false, false, false); } - else if(n >= 2048 && n < 16384) + else if (n >= 2048 && n < 16384) { return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false); } @@ -133,7 +137,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall } else { - if(m < 64) + if (m < 64) { return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 1, false, false, false, false); } @@ -145,9 +149,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall } else { - if(m == 1) + if (m == 1) { - if(n < 8192) + if (n < 8192) { return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 1, false, false, false, false); } @@ -165,4 +169,4 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigNativeValhall } // namespace gemm } // namespace kernels } // namespace opencl -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |