diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp index 5a9ff7990f..099a2c980f 100644 --- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp @@ -40,7 +40,7 @@ #include "src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h" #include "src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h" #include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h" -#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h" +#include "src/core/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h" #include "utils/TypePrinter.h" @@ -127,7 +127,7 @@ inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs TensorInfo tmp_b_info{}; // Validate reshape RHS kernel auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); - if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info))) + if(!bool(opencl::kernels::ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info))) { return false; } @@ -192,7 +192,7 @@ CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemo _weights_to_qasymm8(std::make_unique<CLDepthConvertLayerKernel>()), _mm_native_kernel(std::make_unique<CLGEMMLowpMatrixMultiplyNativeKernel>()), _mm_reshaped_only_rhs_kernel(std::make_unique<CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel>()), - _mtx_b_reshape_kernel(std::make_unique<CLGEMMReshapeRHSMatrixKernel>()), + _mtx_b_reshape_kernel(std::make_unique<opencl::kernels::ClGemmReshapeRhsMatrixKernel>()), _mtx_a_reduction_kernel(std::make_unique<CLGEMMLowpMatrixAReductionKernel>()), _mtx_b_reduction_kernel(std::make_unique<CLGEMMLowpMatrixBReductionKernel>()), _offset_contribution_kernel(std::make_unique<CLGEMMLowpOffsetContributionKernel>()), @@ -292,7 +292,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con a->info(), _convert_to_qasymm8 ? _qasymm8_weights.info() : b->info(), output->info()); // Configure reshape RHS kernel - _mtx_b_reshape_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b, rhs_info); + _mtx_b_reshape_kernel->configure(compile_context, _convert_to_qasymm8 ? _qasymm8_weights.info() : b->info(), _tmp_b.info(), rhs_info); } // Using default reduction info @@ -496,7 +496,7 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso // Validate reshape RHS kernel auto_init_if_empty(tmp_b_info, weights_info.clone()->set_tensor_shape(compute_rhs_reshaped_shape(weights_info, rhs_info))); - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(&weights_info, &tmp_b_info, rhs_info)); + ARM_COMPUTE_RETURN_ON_ERROR(opencl::kernels::ClGemmReshapeRhsMatrixKernel::validate(&weights_info, &tmp_b_info, rhs_info)); } TensorInfo info_vector_sum_col{}; @@ -634,6 +634,9 @@ void CLGEMMLowpMatrixMultiplyCore::run() if(!_reshape_b_only_on_first_run) { // Run reshape matrix B + ITensorPack mtx_b_pack; + mtx_b_pack.add_const_tensor(TensorType::ACL_SRC, _convert_to_qasymm8 ? &_qasymm8_weights : _original_b); + mtx_b_pack.add_tensor(TensorType::ACL_DST, &_tmp_b); CLScheduler::get().enqueue(*_mtx_b_reshape_kernel, false); } } @@ -687,7 +690,10 @@ void CLGEMMLowpMatrixMultiplyCore::prepare() // Run reshape kernel and mark original weights tensor as unused _tmp_b.allocator()->allocate(); - CLScheduler::get().enqueue(*_mtx_b_reshape_kernel, false); + ITensorPack mtx_b_pack; + mtx_b_pack.add_const_tensor(TensorType::ACL_SRC, _convert_to_qasymm8 ? &_qasymm8_weights : _original_b); + mtx_b_pack.add_tensor(TensorType::ACL_DST, &_tmp_b); + CLScheduler::get().enqueue_op(*_mtx_b_reshape_kernel, mtx_b_pack, false); _original_b->mark_as_unused(); } |