aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorAnthony Barbier <anthony.barbier@arm.com>2018-07-23 16:42:59 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit3d677ccee046cd384abf2142f323f8e9e7a4834f (patch)
tree2e0d86a1b2438cb94386c55d1bc89b3e1061214c /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
parent597a85666a84c9a9414264966651551564b79299 (diff)
downloadComputeLibrary-3d677ccee046cd384abf2142f323f8e9e7a4834f.tar.gz
COMPMID-1406: Refactor gemm_interleaved to use our own types and scheduler
- Ported PrepareB kernel from gemm_interleave - Ported TransformA feature from gemm_interleave - Allocate reshaped a and b buffers - Added memory_manager / memory_group - MatrixMultiply kernel - Interleave kernels execution. - Fixed a few bugs: all nightly Convolution tests passing for threads=1 and threads=4 - Added Doxygen documentations and comments in the code - Added support for all data types supported Change-Id: Iffa1c09fda0bb9c61213bb83524d5a48e7ecb03c Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141281 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp111
1 files changed, 93 insertions, 18 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index f17da7d2e4..8ba620fe51 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -24,9 +24,13 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/core/CPP/Validate.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
+#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
#include <arm_neon.h>
@@ -34,8 +38,31 @@ namespace arm_compute
{
namespace
{
+std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+
+{
+ //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
+ {
+ if(!pretranspose_hint)
+ {
+ return nullptr;
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ function->configure(a, b, d, alpha, beta, pretranspose_hint);
+ return std::move(function);
+ }
+ default:
+ return nullptr;
+ }
+}
+
template <typename TypeInput, typename TypeOutput>
-std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
{
ARM_COMPUTE_UNUSED(method);
ARM_COMPUTE_UNUSED(a);
@@ -44,21 +71,63 @@ std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const IT
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(beta);
ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(memory_manager);
return nullptr;
}
+
+#ifdef __aarch64__
template <>
-std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+{
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+ {
+ if(!pretranspose_hint)
+ {
+ return nullptr;
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+ return std::move(function);
+ }
+ default:
+ return nullptr;
+ }
+ return nullptr;
+}
+
+template <>
+std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+{
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+ {
+ if(!pretranspose_hint)
+ {
+ return nullptr;
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+ return std::move(function);
+ }
+ default:
+ return nullptr;
+ }
+ return nullptr;
+}
+
+template <>
+std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
{
- ARM_COMPUTE_UNUSED(method);
- ARM_COMPUTE_UNUSED(a);
- ARM_COMPUTE_UNUSED(b);
- ARM_COMPUTE_UNUSED(d);
- ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_UNUSED(beta);
ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(memory_manager);
switch(method)
{
-#ifdef __aarch64__
case arm_gemm::GemmMethod::GEMM_NATIVE:
{
auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
@@ -67,11 +136,11 @@ std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod me
function->configure(std::move(kernel));
return std::move(function);
}
-#endif /* __aarch64__ */
default:
return nullptr;
}
}
+#endif /* __aarch64__ */
/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput>
@@ -173,11 +242,11 @@ void Fallback<TypeInput, TypeOutput>::prepare()
// Pretranspose B if required
if(_gemm_kernel_asm->B_pretranspose_required())
{
+ ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
_gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
_b->mark_as_unused();
}
@@ -260,7 +329,7 @@ void Fallback<TypeInput, TypeOutput>::run()
template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
- ITensor *d, float alpha, float beta, bool pretranspose_hint)
+ ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -269,7 +338,13 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
//Try to create an ACL function:
- acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint);
+ acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+ // If the type agnostic factory failed to create an ACL function, try the specialised one:
+ if(acl_function == nullptr)
+ {
+ acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+ }
+ //If we still don't have an ACL function:
if(acl_function == nullptr)
{
//Fallback onto arm_gemm function if ACL doesn't support this method.
@@ -282,7 +357,7 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
} //namespace
NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
- : _function(nullptr), _arm_gemm(nullptr), _memory_group(std::move(memory_manager))
+ : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
{
}
@@ -321,20 +396,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITens
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default: