diff options
Diffstat (limited to 'src/core/CL/gemm/reshaped_only_rhs')
-rw-r--r-- | src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp | 64 | ||||
-rw-r--r-- | src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h | 1 |
2 files changed, 65 insertions, 0 deletions
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp index 3105db6693..188ba4d9c5 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp @@ -61,6 +61,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 } }; + // Configurations for Mali-G52 + static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 = + { + { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32 }, + { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 }, + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 } + }; + // Configurations for Mali-G76 static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 = { @@ -94,6 +105,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi { ARM_COMPUTE_ERROR("Not supported data type"); } + case GPUTarget::G52: + if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end()) + { + return (this->*gemm_configs_G52[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } case GPUTarget::G51: if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end()) { @@ -201,6 +221,50 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi } } +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + const float r_nk = static_cast<float>(n) / static_cast<float>(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if(m == 1) + { + if(r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(workload <= 274.4000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } +} + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h index 618dbd9923..3dfd96a822 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h @@ -46,6 +46,7 @@ public: private: std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); |