diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2018-08-01 15:06:06 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | 597a85666a84c9a9414264966651551564b79299 (patch) | |
tree | 6f2fd1bd8648c495b7e3324433ed902266fb2053 /src/runtime/NEON/functions/NEGEMM.cpp | |
parent | 883f489da93e88d74aa0dfb206c56697ba0e63f0 (diff) | |
download | ComputeLibrary-597a85666a84c9a9414264966651551564b79299.tar.gz |
COMPMID-872 - Rework NEGEMMConvolutionLayer to use NEGEMM
Change-Id: I55f0018ac7214775ebbca63f58a3bf5c93732fec
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/142632
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMM.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMM.cpp | 165 |
1 files changed, 58 insertions, 107 deletions
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index e47ef86a1c..de51266267 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -59,32 +59,20 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe _original_b = b; bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run)); + if(run_optimised) { _asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run); - run_optimised = _asm_glue.is_configured(); + ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured()); } - - // Check if the first input tensor is a vector. - // If so, all the kernels for reshaping the tensors can be skipped - if(_run_vector_matrix_multiplication) + else { - if(!run_optimised) + if(_run_vector_matrix_multiplication) { // Configure the matrix multiply kernel _mm_kernel.configure(a, b, d, alpha, false); } - - // Configure matrix addition kernel - if(beta != 0 && c != nullptr) - { - _ma_kernel.configure(c, d, beta); - _run_addition = true; - } - } - else - { - if(!run_optimised) + else { TensorShape shape_tmp_a = a->info()->tensor_shape(); TensorShape shape_tmp_b = b->info()->tensor_shape(); @@ -128,13 +116,13 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe { _tmp_b.allocator()->allocate(); } + } - // Configure matrix addition kernel - if(beta != 0 && c != nullptr) - { - _ma_kernel.configure(c, d, beta); - _run_addition = true; - } + // Configure matrix addition kernel + if(beta != 0 && c != nullptr) + { + _ma_kernel.configure(c, d, beta); + _run_addition = true; } } } @@ -152,7 +140,8 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso if(c != nullptr) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d()); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B"); @@ -161,110 +150,72 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso if(output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0)); - ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1)); + if(gemm_info.depth_output_gemm3d() != 1) + { + if(gemm_info.reinterpret_input_as_3d()) + { + ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1)); + ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != output->dimension(2)); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1) * output->dimension(2)); + } + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1)); + } } - // Check if the first input tensor is a vector. - const bool run_vector_matrix_multiplication = a->dimension(1) < 2; - // Check if we need to reshape the matrix A and matrix B - const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run()); // Check if we need to run the optimized assembly kernel const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true)); - const ITensorInfo *matrix_a_info = a; - const ITensorInfo *matrix_b_info = b; - - TensorInfo tmp_a_info{}; - TensorInfo tmp_b_info{}; - TensorInfo tmp_output_info = *output->clone(); - - // Arguments used by GEMMReshapeInfo - // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to NEGEMMReshapeInfo - // in order to know how the matrices have been reshaped - const int m = a->dimension(1); - const int n = b->dimension(0); - const int k = a->dimension(0); - int mult_transpose1xW_width = 1; - int mult_interleave4x4_height = 1; - - const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d()); - - // Initialize shapes - if(run_interleave_transpose) + if(!run_optimised) { - matrix_a_info = &tmp_a_info; - auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height))); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMM cannot reinterpret the input tensor as 3D"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 1, "NEGEMM cannot reinterpret the output tensor as 3D"); - matrix_b_info = &tmp_b_info; - auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width))); + // Check if the first input tensor is a vector. + const bool run_vector_matrix_multiplication = a->dimension(1) < 2; + // Check if we need to reshape the matrix A and matrix B + const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run()); - auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info))); - } + // Arguments used by GEMMReshapeInfo + // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to NEGEMMReshapeInfo + // in order to know how the matrices have been reshaped + const int m = a->dimension(1); + const int n = b->dimension(0); + const int k = a->dimension(0); + int mult_transpose1xW_width = 1; + int mult_interleave4x4_height = 1; - // Validate kernels - if(run_optimised && run_interleave_transpose) - { - /* Interleave */ - TensorShape tensor_shape0{ matrix_a_info->tensor_shape() }; - tensor_shape0.set(0, k); - tensor_shape0.set(1, m); + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d()); - const TensorInfo tensor_info0 = matrix_a_info->clone()->set_tensor_shape(tensor_shape0); - const TensorInfo tensor_info_reshaped0 = matrix_a_info->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_a_info, &tensor_info_reshaped0); + const ITensorInfo *matrix_a_info = a; + const ITensorInfo *matrix_b_info = b; - if(n != 0) /* Transpose */ - { - TensorShape tensor_shape1{ matrix_b_info->tensor_shape() }; - tensor_shape1.set(0, n); - tensor_shape1.set(1, k); + TensorInfo tmp_a_info{}; + TensorInfo tmp_b_info{}; + TensorInfo tmp_output_info = *output->clone(); - const TensorInfo tensor_info1 = matrix_b_info->clone()->set_tensor_shape(tensor_shape1); - const TensorInfo tensor_info_reshaped1 = matrix_b_info->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_b_info, &tensor_info_reshaped1); - } - - if(output->total_size() != 0) - { - if(n != 0) - { - ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(0) != static_cast<size_t>(n)); - } - ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(1) != static_cast<size_t>(m)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(matrix_a_info, &tmp_output_info); - } - } - else if(run_vector_matrix_multiplication) - { - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(a, b, output, alpha, false, reshape_info)); - - if(beta != 0 && c != nullptr) - { - // Validate matrix addition kernel - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, output); - } - } - else - { if(run_interleave_transpose) { + matrix_a_info = &tmp_a_info; + matrix_b_info = &tmp_b_info; + // Validate interleave kernel - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, matrix_a_info)); + auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()))); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &tmp_a_info)); // Validate transpose kernel - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, matrix_b_info)); + auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width))); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &tmp_b_info)); } // Validate matrix multiply + auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info))); ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info)); - - if(beta != 0 && c != nullptr) - { - // Validate matrix addition kernel - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, &tmp_output_info); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, &tmp_output_info); - } } return Status{}; |