aboutsummaryrefslogtreecommitdiff
path: root/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp')
-rw-r--r--src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp33
1 files changed, 31 insertions, 2 deletions
diff --git a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
index 7ed6b39f3e..417d540468 100644
--- a/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
+++ b/src/core/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
@@ -30,7 +30,6 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/gpu/cl/kernels/gemm/ClGemmHelpers.h"
-
#include <utility>
namespace arm_compute
@@ -61,6 +60,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
+
CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32,
&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
@@ -70,7 +73,6 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
ConfigurationFunctionExecutorPtr func = nullptr;
-
switch(_target)
{
case GPUTarget::G76:
@@ -82,6 +84,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
case GPUTarget::G52:
func = configs_G52.get_function(data_type);
break;
+ case GPUTarget::G31:
+ func = configs_G31.get_function(data_type);
+ break;
default:
func = configs_G7x.get_function(data_type);
break;
@@ -113,6 +118,30 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(m == 1)
+ {
+ const unsigned int h0 = std::max(n / 2, 1U);
+ return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
+ }
+ else
+ {
+ const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
+ if(m >= 28)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
+ }
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);