aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp30
1 files changed, 16 insertions, 14 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index e81d8a6b97..9867229a7c 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -39,7 +39,7 @@
using namespace arm_compute;
CLGEMM::CLGEMM()
- : _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _run_vector_matrix_multiplication(false), _run_addition(false)
+ : _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _is_interleaved_transposed(false), _run_addition(false)
{
}
@@ -59,12 +59,16 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(0) != b->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
- _mm_kernel.set_target(CLScheduler::get().target());
+ // If the input tensor has less than 16 rows, we run a special version of GEMM without reshaping the input tensors
+ _is_interleaved_transposed = a->info()->dimension(1) > 16;
- // Check if the first input tensor is a vector. If so, all the kernels for reshaping the tensors can be skipped
- if(a->info()->dimension(1) != 1)
+ const ICLTensor *matrix_a = a;
+ const ICLTensor *matrix_b = b;
+
+ if(_is_interleaved_transposed)
{
- _run_vector_matrix_multiplication = false;
+ matrix_a = &_tmp_a;
+ matrix_b = &_tmp_b;
TensorShape shape_tmp_a = a->info()->tensor_shape();
TensorShape shape_tmp_b = b->info()->tensor_shape();
@@ -89,19 +93,17 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
_transpose_kernel.configure(b, &_tmp_b);
// Configure matrix multiply kernel
- _mm_kernel.configure(&_tmp_a, &_tmp_b, output, alpha);
+ _mm_kernel.set_target(CLScheduler::get().target());
+ }
+ _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed);
+
+ if(_is_interleaved_transposed)
+ {
// Allocate intermediate tensors
_tmp_a.allocator()->allocate();
_tmp_b.allocator()->allocate();
}
- else // The first input tensor is a vector
- {
- _run_vector_matrix_multiplication = true;
-
- // Configure the matrix multiply kernel
- _mm_kernel.configure(a, b, output, alpha);
- }
// Configure matrix addition kernel
if(beta != 0 && c != nullptr)
@@ -113,7 +115,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
void CLGEMM::run()
{
- if(!_run_vector_matrix_multiplication)
+ if(_is_interleaved_transposed)
{
// Run interleave kernel
CLScheduler::get().enqueue(_interleave_kernel, false);