aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMM.cpp
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-07-16 17:20:38 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commita855af10a486c53c2271361cb87f349eca64b749 (patch)
treeb326b63bdcaf76c9620b1bbf22942d4683503a65 /src/runtime/NEON/functions/NEGEMM.cpp
parent5a3ee4f708a9e1642b0211955ff905e7b67e831d (diff)
downloadComputeLibrary-a855af10a486c53c2271361cb87f349eca64b749.tar.gz
COMPMID-1401 Implement NEFullyConnectedLayer for QASYMM8
Change-Id: I0404df6d369855e2f458f2db8f26e81c80a1ee87 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140148 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMM.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMM.cpp154
1 files changed, 137 insertions, 17 deletions
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index c958904b93..e47ef86a1c 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -23,12 +23,14 @@
*/
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
+#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/runtime/TensorAllocator.h"
@@ -36,6 +38,8 @@
#include <cmath>
+using namespace arm_compute::misc::shape_calculator;
+
namespace arm_compute
{
NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager)
@@ -46,21 +50,7 @@ NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager)
void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, d);
- ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(0) != b->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
- ARM_COMPUTE_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
- ARM_COMPUTE_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
-
- if(c != nullptr)
- {
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, c);
- ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(1) != c->info()->dimension(1), "The C matrix must have the same number of rows as the matrix A");
- ARM_COMPUTE_ERROR_ON_MSG(b->info()->dimension(0) != c->info()->dimension(0), "The C matrix must have the same number of columns as the matrix B");
- ARM_COMPUTE_ERROR_ON_MSG(c->info()->dimension(0) != d->info()->dimension(0), "The C matrix must have the same number of rows as the output matrix");
- ARM_COMPUTE_ERROR_ON_MSG(c->info()->dimension(1) != d->info()->dimension(1), "The C matrix must have the same number of columns as the output matrix");
- }
+ ARM_COMPUTE_ERROR_THROW_ON(NEGEMM::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info));
// Check if we need to reshape the matrix B only on the first run
_is_prepared = false;
@@ -68,7 +58,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
_run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
_original_b = b;
- bool run_optimised = a->info()->data_type() == DataType::F32 && (c == nullptr || beta == 0.f);
+ bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run));
if(run_optimised)
{
_asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run);
@@ -149,6 +139,137 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
}
}
+Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+ ARM_COMPUTE_UNUSED(alpha);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+
+ if(c != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B");
+ }
+
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != output->dimension(1));
+ }
+
+ // Check if the first input tensor is a vector.
+ const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
+ // Check if we need to reshape the matrix A and matrix B
+ const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
+ // Check if we need to run the optimized assembly kernel
+ const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true));
+
+ const ITensorInfo *matrix_a_info = a;
+ const ITensorInfo *matrix_b_info = b;
+
+ TensorInfo tmp_a_info{};
+ TensorInfo tmp_b_info{};
+ TensorInfo tmp_output_info = *output->clone();
+
+ // Arguments used by GEMMReshapeInfo
+ // If we pass the matrix A and matrix B reshaped to NEGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to NEGEMMReshapeInfo
+ // in order to know how the matrices have been reshaped
+ const int m = a->dimension(1);
+ const int n = b->dimension(0);
+ const int k = a->dimension(0);
+ int mult_transpose1xW_width = 1;
+ int mult_interleave4x4_height = 1;
+
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
+
+ // Initialize shapes
+ if(run_interleave_transpose)
+ {
+ matrix_a_info = &tmp_a_info;
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height)));
+
+ matrix_b_info = &tmp_b_info;
+ auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
+
+ auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
+ }
+
+ // Validate kernels
+ if(run_optimised && run_interleave_transpose)
+ {
+ /* Interleave */
+ TensorShape tensor_shape0{ matrix_a_info->tensor_shape() };
+ tensor_shape0.set(0, k);
+ tensor_shape0.set(1, m);
+
+ const TensorInfo tensor_info0 = matrix_a_info->clone()->set_tensor_shape(tensor_shape0);
+ const TensorInfo tensor_info_reshaped0 = matrix_a_info->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_a_info, &tensor_info_reshaped0);
+
+ if(n != 0) /* Transpose */
+ {
+ TensorShape tensor_shape1{ matrix_b_info->tensor_shape() };
+ tensor_shape1.set(0, n);
+ tensor_shape1.set(1, k);
+
+ const TensorInfo tensor_info1 = matrix_b_info->clone()->set_tensor_shape(tensor_shape1);
+ const TensorInfo tensor_info_reshaped1 = matrix_b_info->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(matrix_b_info, &tensor_info_reshaped1);
+ }
+
+ if(output->total_size() != 0)
+ {
+ if(n != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(0) != static_cast<size_t>(n));
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON(tmp_output_info.dimension(1) != static_cast<size_t>(m));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(matrix_a_info, &tmp_output_info);
+ }
+ }
+ else if(run_vector_matrix_multiplication)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(a, b, output, alpha, false, reshape_info));
+
+ if(beta != 0 && c != nullptr)
+ {
+ // Validate matrix addition kernel
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, output);
+ }
+ }
+ else
+ {
+ if(run_interleave_transpose)
+ {
+ // Validate interleave kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(a, matrix_a_info));
+
+ // Validate transpose kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMTranspose1xWKernel::validate(b, matrix_b_info));
+ }
+
+ // Validate matrix multiply
+ ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
+
+ if(beta != 0 && c != nullptr)
+ {
+ // Validate matrix addition kernel
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, &tmp_output_info);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(c, &tmp_output_info);
+ }
+ }
+
+ return Status{};
+}
+
void NEGEMM::run()
{
prepare();
@@ -196,7 +317,6 @@ void NEGEMM::prepare()
ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
_asm_glue.prepare();
- _original_b->mark_as_unused();
}
else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured())
{