path: root/src/runtime/NEON/functions/NEGEMM.cpp
author    Gian Marco Iodice <gianmarco.iodice@arm.com>  2018-08-01 15:06:06 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:54:54 +0000
commit    597a85666a84c9a9414264966651551564b79299 (patch)
tree      6f2fd1bd8648c495b7e3324433ed902266fb2053 /src/runtime/NEON/functions/NEGEMM.cpp
parent    883f489da93e88d74aa0dfb206c56697ba0e63f0 (diff)
download  ComputeLibrary-597a85666a84c9a9414264966651551564b79299.tar.gz
COMPMID-872 - Rework NEGEMMConvolutionLayer to use NEGEMM
Change-Id: I55f0018ac7214775ebbca63f58a3bf5c93732fec
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/142632
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMM.cpp')
-rw-r--r--  src/runtime/NEON/functions/NEGEMM.cpp | 165
1 file changed, 58 insertions(+), 107 deletions(-)
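For context, below is a minimal usage sketch of the NEGEMM function that this patch reworks. It is not part of the commit; the tensor shapes, the alpha/beta values, and the run_gemm_example wrapper are illustrative. Passing c == nullptr lets configure() consider the optimized assembly path that the patch now asserts on.

// Minimal usage sketch (illustrative, not part of the patch).
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_gemm_example()
{
    Tensor a, b, d;
    // Shapes are (x, y) = (columns, rows): A is 32x64, B is 64x16, D is 32x16.
    a.allocator()->init(TensorInfo(TensorShape(64U, 32U), 1, DataType::F32));
    b.allocator()->init(TensorInfo(TensorShape(16U, 64U), 1, DataType::F32));
    d.allocator()->init(TensorInfo(TensorShape(16U, 32U), 1, DataType::F32));

    NEGEMM gemm;
    // c == nullptr allows the optimized assembly dispatch checked in this patch.
    gemm.configure(&a, &b, nullptr, &d, 1.0f, 0.0f);

    a.allocator()->allocate();
    b.allocator()->allocate();
    d.allocator()->allocate();

    gemm.run(); // D = alpha * A * B
}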
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index e47ef86a1c..de51266267 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -59,32 +59,20 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
_original_b = b;
bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run));
+
if(run_optimised)
{
_asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run);
- run_optimised = _asm_glue.is_configured();
+ ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured());
}
-
- // Check if the first input tensor is a vector.
- // If so, all the kernels for reshaping the tensors can be skipped
- if(_run_vector_matrix_multiplication)
+ else
{
- if(!run_optimised)
+ if(_run_vector_matrix_multiplication)
{
// Configure the matrix multiply kernel
_mm_kernel.configure(a, b, d, alpha, false);
}
-
- // Configure matrix addition kernel
- if(beta != 0 && c != nullptr)
- {
- _ma_kernel.configure(c, d, beta);
- _run_addition = true;
- }
- }
- else
- {
- if(!run_optimised)
+ else
{
TensorShape shape_tmp_a = a->info()->tensor_shape();
TensorShape shape_tmp_b = b->info()->tensor_shape();
@@ -128,13 +116,13 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
{
_tmp_b.allocator()->allocate();
}
+ }
- // Configure matrix addition kernel
- if(beta != 0 && c != nullptr)
- {
- _ma_kernel.configure(c, d, beta);
- _run_addition = true;
- }
+ // Configure matrix addition kernel
+ if(beta != 0 && c != nullptr)
+ {
+ _ma_kernel.configure(c, d, beta);
+ _run_addition = true;
}
}
}
@@ -152,7 +140,8 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
if(c != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16);
+ ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B");
@@ -161,110 +150,72 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0));
- ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+ if(gemm_info.depth_output_gemm3d() != 1)
+ {
+ if(gemm_info.reinterpret_input_as_3d())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != output->dimension(2));
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1) * output->dimension(2));
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+ }
}
- // Check if the first input tensor is a vector.
- const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
- // Check if we need to reshape the matrix A and matrix B
- const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
// Check if we need to run the optimized assembly kernel
const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true));
- const ITensorInfo *matrix_a_info = a;
- const ITensorInfo *matrix_b_info = b;
-
- TensorInfo tmp_a_info{};
- TensorInfo tmp_b_info{};
- TensorInfo tmp_output_info = *output->clone();
-
- // Arguments used by GEMMReshapeInfo
- // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to NEGEMMReshapeInfo
- // in order to know how the matrices have been reshaped
- const int m = a->dimension(1);
- const int n = b->dimension(0);
- const int k = a->dimension(0);
- int mult_transpose1xW_width = 1;
- int mult_interleave4x4_height = 1;
-
- const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
-
- // Initialize shapes
- if(run_interleave_transpose)
+ if(!run_optimised)
{
- matrix_a_info = &tmp_a_info;
- auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height)));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMM cannot reinterpret the input tensor as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 1, "NEGEMM cannot reinterpret the output tensor as 3D");
- matrix_b_info = &tmp_b_info;
- auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
+ // Check if the first input tensor is a vector.
+ const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
+ // Check if we need to reshape the matrix A and matrix B
+ const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
- auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
- }
+ // Arguments used by GEMMReshapeInfo
+ // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to NEGEMMReshapeInfo
+ // in order to know how the matrices have been reshaped
+ const int m = a->dimension(1);
+ const int n = b->dimension(0);
+ const int k = a->dimension(0);
+ int mult_transpose1xW_width = 1;
+ int mult_interleave4x4_height = 1;
- // Validate kernels
- if(run_optimised && run_interleave_transpose)
- {
- /* Interleave */
- TensorShape tensor_shape0{ matrix_a_info->tensor_shape() };
- tensor_shape0.set(0, k);
- tensor_shape0.set(1, m);
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
- const TensorInfo tensor_info0 = matrix_a_info->clone()->set_tensor_shape(tensor_shape0);
- const TensorInfo tensor_info_reshaped0 = matrix_a_info->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_a_info, &tensor_info_reshaped0);
+ const ITensorInfo *matrix_a_info = a;
+ const ITensorInfo *matrix_b_info = b;
- if(n != 0) /* Transpose */
- {
- TensorShape tensor_shape1{ matrix_b_info->tensor_shape() };
- tensor_shape1.set(0, n);
- tensor_shape1.set(1, k);
+ TensorInfo tmp_a_info{};
+ TensorInfo tmp_b_info{};
+ TensorInfo tmp_output_info = *output->clone();
- const TensorInfo tensor_info1 = matrix_b_info->clone()->set_tensor_shape(tensor_shape1);
- const TensorInfo tensor_info_reshaped1 = matrix_b_info->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_b_info, &tensor_info_reshaped1);
- }
-
- if(output->total_size() != 0)
- {
- if(n != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(0) != static_cast<size_t>(n));
- }
- ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(1) != static_cast<size_t>(m));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(matrix_a_info, &tmp_output_info);
- }
- }
- else if(run_vector_matrix_multiplication)
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(a, b, output, alpha, false, reshape_info));
-
- if(beta != 0 && c != nullptr)
- {
- // Validate matrix addition kernel
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, output);
- }
- }
- else
- {
if(run_interleave_transpose)
{
+ matrix_a_info = &tmp_a_info;
+ matrix_b_info = &tmp_b_info;
+
// Validate interleave kernel
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, matrix_a_info));
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, &tmp_a_info));
// Validate transpose kernel
- ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, matrix_b_info));
+ auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, &tmp_b_info));
}
// Validate matrix multiply
+ auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
-
- if(beta != 0 && c != nullptr)
- {
- // Validate matrix addition kernel
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, &tmp_output_info);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, &tmp_output_info);
- }
}
return Status{};
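Continuing the sketch above the diff, callers can consult the static validate() path reworked in this hunk before calling configure(). The default-constructed GEMMInfo is an assumption for illustration; it simply leaves the 3D-reinterpretation options at their defaults.

// Hedged sketch, assuming the tensors a, b, d from the example above.
#include <iostream>

arm_compute::Status status = arm_compute::NEGEMM::validate(
    a.info(), b.info(), nullptr, d.info(), 1.0f, 0.0f, arm_compute::GEMMInfo());
if(!bool(status))
{
    // Report why the GEMM configuration was rejected.
    std::cout << status.error_description() << std::endl;
}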