aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp18
1 files changed, 10 insertions, 8 deletions
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index c0bd85dcb5..c447cb8778 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
@@ -31,7 +32,6 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
-#include "arm_compute/runtime/CL/gemm_reshaped/CLGEMMReshapedConfiguration.h"
namespace arm_compute
{
@@ -122,12 +122,12 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
}
// Pick up the GEMM configuration
- std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+ std::tie(lhs_info, rhs_info) = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
- // Configure interleave kernel
+ // Configure reshape LHS kernel
_mtx_a_reshape_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
- // Configure transpose kernel
+ // Configure reshape RHS kernel
_mtx_b_reshape_kernel.configure(b, &_tmp_b, rhs_info);
}
@@ -236,6 +236,9 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
GEMMRHSMatrixInfo rhs_info;
GEMMLHSMatrixInfo lhs_info;
+ // Get the GPU target
+ const GPUTarget gpu_target = CLScheduler::get().target();
+
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
const unsigned int n = b->dimension(0);
@@ -259,14 +262,13 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
matrix_b_info = &tmp_b_info;
// Pick up the GEMM configuration
- std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+ std::tie(lhs_info, rhs_info) = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
- // Validate interleave kernel
+ // Validate reshape LHS kernel
auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
- // Validate transpose kernel
-
+ // Validate reshape RHS kernel
auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
}