diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index e81d8a6b97..9867229a7c 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -39,7 +39,7 @@ using namespace arm_compute; CLGEMM::CLGEMM() - : _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _run_vector_matrix_multiplication(false), _run_addition(false) + : _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _is_interleaved_transposed(false), _run_addition(false) { } @@ -59,12 +59,16 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(0) != b->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B"); - _mm_kernel.set_target(CLScheduler::get().target()); + // If the first input tensor has 16 rows or fewer, we run a special version of GEMM without reshaping the input tensors + _is_interleaved_transposed = a->info()->dimension(1) > 16; - // Check if the first input tensor is a vector. 
If so, all the kernels for reshaping the tensors can be skipped - if(a->info()->dimension(1) != 1) + const ICLTensor *matrix_a = a; + const ICLTensor *matrix_b = b; + + if(_is_interleaved_transposed) { - _run_vector_matrix_multiplication = false; + matrix_a = &_tmp_a; + matrix_b = &_tmp_b; TensorShape shape_tmp_a = a->info()->tensor_shape(); TensorShape shape_tmp_b = b->info()->tensor_shape(); @@ -89,19 +93,17 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * _transpose_kernel.configure(b, &_tmp_b); // Configure matrix multiply kernel - _mm_kernel.configure(&_tmp_a, &_tmp_b, output, alpha); + _mm_kernel.set_target(CLScheduler::get().target()); + } + _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed); + + if(_is_interleaved_transposed) + { // Allocate intermediate tensors _tmp_a.allocator()->allocate(); _tmp_b.allocator()->allocate(); } - else // The first input tensor is a vector - { - _run_vector_matrix_multiplication = true; - - // Configure the matrix multiply kernel - _mm_kernel.configure(a, b, output, alpha); - } // Configure matrix addition kernel if(beta != 0 && c != nullptr) @@ -113,7 +115,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * void CLGEMM::run() { - if(!_run_vector_matrix_multiplication) + if(_is_interleaved_transposed) { // Run interleave kernel CLScheduler::get().enqueue(_interleave_kernel, false); |