aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/assembly
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-09 18:35:17 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-18 13:41:40 +0000
commit7cd26d4a1b14bc4bf7c61496803416ab3d84791f (patch)
tree12cc4a27d7ecebc69a43e96b1f46c7eb05437978 /src/runtime/NEON/functions/assembly
parent3ac2f3a1d9297220d1b0ce920dd13fdd4edcc187 (diff)
downloadComputeLibrary-7cd26d4a1b14bc4bf7c61496803416ab3d84791f.tar.gz
COMPMID-1867: Add NEON/SVE GEMM Hybrid kernels.
Change-Id: Ib40a9921e7f9a6a8be6c38872d6b3a0f24ed0cd3 Reviewed-on: https://review.mlplatform.org/515 Reviewed-by: Anthony Barbier <Anthony.barbier@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/assembly')
-rw-r--r--src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp142
1 files changed, 22 insertions, 120 deletions
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
index fe998a0e42..695fc859de 100644
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,12 +26,11 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/kernels/assembly/Helpers.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
-#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h"
+
#include <atomic>
#include <condition_variable>
#include <mutex>
@@ -179,6 +178,7 @@ NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManage
: _memory_group(std::move(memory_manager))
{
}
+
void NEGEMMInterleavedWrapper::run()
{
prepare();
@@ -334,38 +334,7 @@ void NEGEMMInterleavedWrapper::prepare()
}
}
-namespace
-{
-// Factory to instantiate NEGEMMInterleavedPrepareBWrapperKernel:
-template <typename InputType, bool use_dot = false>
-std::unique_ptr<NEGEMMInterleavedPrepareBWrapperKernel> instantiate_prepareB(const ITensor *b, ITensor *transformed_b, const INEGEMMWrapperKernel::Params &params)
-{
- auto prepare_b = support::cpp14::make_unique<NEGEMMInterleavedPrepareBWrapperKernelTemplate<InputType, use_dot>>();
- prepare_b->configure(b, transformed_b, false, NEScheduler::get().cpu_info(), params);
- return std::move(prepare_b);
-}
-
-// Factory to instantiate NEGEMMInterleavedTransformAWrapperTemplate:
-template <typename InputType, bool use_dot = false>
-std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
-{
- auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<InputType, use_dot>>();
- transform_a->configure(a, transformed_a, false, block_walker, params);
- return std::move(transform_a);
-}
-
-// Factory to instantiate NEGEMMInterleavedTransformAWrapperTemplate:
-template <typename InputType, typename OutputType, bool use_dot = false>
-std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker,
- const BlockSizes &block_sizes, const INEGEMMWrapperKernel::Params &params, bool pretranspose_b, float alpha, float beta)
-{
- auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<InputType, OutputType, use_dot>>();
- matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, NEScheduler::get().num_threads());
- return std::move(matrix_multiply);
-}
-} // namespace
-
-void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b, bool use_dot)
+void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b)
{
_params = INEGEMMWrapperKernel::extract_parameters(a, b, c);
_a = a;
@@ -373,18 +342,26 @@ void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITe
_c = c;
_pretranspose_b = pretranspose_b;
- DataType input_type = a->info()->data_type();
+ const DataType input_type = a->info()->data_type();
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ const unsigned int num_threads = NEScheduler::get().num_threads();
+
+ const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, pretranspose_b);
+ ARM_COMPUTE_ERROR_ON(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMM_INTERLEAVED);
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
_transformed_b.allocator()->init(TensorInfo{}, alignment);
_tmp_c.allocator()->init(TensorInfo{}, alignment);
- _tag = "NEGEMMInterleaved_";
- _tag += get_strategy_name(input_type, use_dot);
+ _tag = "NEGEMMInterleaved_" + gemm_kernel_info.name;
+
+ // Get strategy
+ std::unique_ptr<detail::IInterleavedStrategy> strategy = detail::create_strategy(gemm_kernel_info.name);
+ ARM_COMPUTE_ERROR_ON(strategy == nullptr);
if(!_pretranspose_b)
{
- _block_sizes = calculate_block_sizes_from_data_type(NEScheduler::get().cpu_info(), _params.M, _params.N, _params.K, input_type, use_dot);
+ _block_sizes = strategy->calculate_block_sizes_for_strategy(ci, _params);
_batch_window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_block_sizes.m_round, _block_sizes.strategy_out_height), _block_sizes.strategy_out_height));
_batch_window.set(Window::DimY, Window::Dimension(0, _params.batches));
// If the execution is single threaded or has only one window then the buffer manager only needs 1 buffer else we will use NUM_BUFFERS buffers and ping pong between them:
@@ -409,43 +386,8 @@ void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITe
{
_tag += "_preB";
}
- switch(input_type)
- {
- case DataType::F32:
- _prepare_b = instantiate_prepareB<float>(_b, &_transformed_b, _params);
- break;
-#ifdef __aarch64__
- case DataType::U8:
- case DataType::QASYMM8:
- if(use_dot)
- {
- _prepare_b = instantiate_prepareB<uint8_t, true>(_b, &_transformed_b, _params);
- }
- else
- {
- _prepare_b = instantiate_prepareB<uint8_t, false>(_b, &_transformed_b, _params);
- }
- break;
- case DataType::S8:
- if(use_dot)
- {
- _prepare_b = instantiate_prepareB<int8_t, true>(_b, &_transformed_b, _params);
- }
- else
- {
- _prepare_b = instantiate_prepareB<int8_t, false>(_b, &_transformed_b, _params);
- }
- break;
-#endif /* __aarch64__ */
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- _prepare_b = instantiate_prepareB<__fp16>(_b, &_transformed_b, _params);
- break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- ARM_COMPUTE_ERROR("DataType not supported");
- break;
- }
+
+ _prepare_b = strategy->instantiate_prepareB(b, &_transformed_b, _params, ci);
ARM_COMPUTE_ERROR_ON(_prepare_b == nullptr);
if(_pretranspose_b)
@@ -463,51 +405,11 @@ void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITe
_memory_group.manage(&_transformed_a);
_memory_group.manage(&_tmp_c);
- switch(input_type)
- {
- case DataType::F32:
- _transform_a = instantiate_transformA<float>(_a, &_transformed_a, _block_walker, _params);
- _matrix_multiply = instantiate_matrix_multiply<float, float>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
- break;
-#ifdef __aarch64__
- case DataType::U8:
- case DataType::QASYMM8:
- if(use_dot)
- {
- _transform_a = instantiate_transformA<uint8_t, true>(_a, &_transformed_a, _block_walker, _params);
- _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
- }
- else
- {
- _transform_a = instantiate_transformA<uint8_t, false>(_a, &_transformed_a, _block_walker, _params);
- _matrix_multiply = instantiate_matrix_multiply<uint8_t, uint32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
- }
- break;
- case DataType::S8:
- if(use_dot)
- {
- _transform_a = instantiate_transformA<int8_t, true>(_a, &_transformed_a, _block_walker, _params);
- _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, true>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
- }
- else
- {
- _transform_a = instantiate_transformA<int8_t, false>(_a, &_transformed_a, _block_walker, _params);
- _matrix_multiply = instantiate_matrix_multiply<int8_t, int32_t, false>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
- }
- break;
-#endif /* __aarch64__ */
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- _transform_a = instantiate_transformA<__fp16>(_a, &_transformed_a, _block_walker, _params);
- _matrix_multiply = instantiate_matrix_multiply<__fp16, __fp16>(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, pretranspose_b, alpha, beta);
- break;
- break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- break;
- }
+ _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params);
+ _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, pretranspose_b, num_threads);
ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);
+
_transformed_a.allocator()->allocate();
_tmp_c.allocator()->allocate();
if(!_pretranspose_b)