diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 32 |
1 files changed, 24 insertions, 8 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 1f4df4f1a9..1d1b17bbf1 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -102,7 +102,8 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * // Arguments used by GEMMReshapeInfo // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo // in order to know how the matrices have been reshaped - const int m = a->info()->dimension(1); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); const int n = b->info()->dimension(0); const int k = a->info()->dimension(0); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); @@ -118,6 +119,12 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * // Check if we need to reshape the matrix A and matrix B _is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target); + // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D + if(_is_interleaved_transposed) + { + reinterpret_input_as_3d = false; + } + if(_is_interleaved_transposed) { matrix_a = &_tmp_a; @@ -132,14 +139,15 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel // Configure interleave kernel - _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height); + _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()); // Configure transpose kernel _transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width); } // Configure and tune matrix multiply kernel - _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d)); + _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, + reinterpret_input_as_3d)); CLScheduler::get().tune_kernel_static(_mm_kernel); if(_is_interleaved_transposed) @@ -180,11 +188,13 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso // Arguments used by GEMMReshapeInfo // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo // in order to know how the matrices have been reshaped - const int m = a->dimension(1); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); const int n = b->dimension(0); const int k = a->dimension(0); int mult_transpose1xW_width = 1; int mult_interleave4x4_height = 1; + const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST) { @@ -192,19 +202,25 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso mult_interleave4x4_height = 2; } - const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d()); - // Check if we need to reshape the matrix A and matrix B const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target); + // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D + if(run_interleave_transpose) + { + reinterpret_input_as_3d = false; + } + + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d); + if(run_interleave_transpose) { matrix_a_info = &tmp_a_info; matrix_b_info = &tmp_b_info; // Validate interleave kernel - auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height))); - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height)); + auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()))); + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())); // Validate transpose kernel auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width))); |