aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorAnthony Barbier <anthony.barbier@arm.com>2018-07-17 16:48:42 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitc8e84b5a3872eda6748d77dbaf8548ad99f4c0cd (patch)
tree0c519a97b7f0ff89352a7736be1cae43b6dea10e /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
parent3efb37536149f438a68a1742c35d827e1fbd7860 (diff)
downloadComputeLibrary-c8e84b5a3872eda6748d77dbaf8548ad99f4c0cd.tar.gz
COMPMID-1405: Create our own gemm_native kernel / function.
Change-Id: Ie0a80bd6b4eb5632cac63ccf54bcb07d4309da19 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140305 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp114
1 files changed, 88 insertions, 26 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index e796a6a56e..f4710fab84 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -23,21 +23,74 @@
*/
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
-using namespace arm_compute;
-
+namespace arm_compute
+{
template <typename TypeInput, typename TypeOutput>
NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
: _function(nullptr), _arm_gemm(), _memory_group(std::move(memory_manager))
{
}
+template <>
+bool NEGEMMAssemblyDispatch<float, float>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+ ARM_COMPUTE_UNUSED(method);
+ ARM_COMPUTE_UNUSED(a);
+ ARM_COMPUTE_UNUSED(b);
+ ARM_COMPUTE_UNUSED(d);
+ ARM_COMPUTE_UNUSED(alpha);
+ ARM_COMPUTE_UNUSED(beta);
+ ARM_COMPUTE_UNUSED(pretranspose_hint);
+ switch(method)
+ {
+#ifdef __aarch64__
+ case arm_gemm::GemmMethod::GEMM_NATIVE:
+ {
+ auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
+ kernel->configure(a, b, d, alpha, beta);
+ auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
+ function->configure(std::move(kernel));
+ _function = std::move(function);
+ return true;
+ }
+#endif /* __aarch64__ */
+ default:
+ return false;
+ }
+}
+
+template <typename TypeInput, typename TypeOutput>
+bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+ ARM_COMPUTE_UNUSED(method);
+ ARM_COMPUTE_UNUSED(a);
+ ARM_COMPUTE_UNUSED(b);
+ ARM_COMPUTE_UNUSED(d);
+ ARM_COMPUTE_UNUSED(alpha);
+ ARM_COMPUTE_UNUSED(beta);
+ ARM_COMPUTE_UNUSED(pretranspose_hint);
+ return false;
+}
+
template <typename TypeInput, typename TypeOutput>
void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
- //TODO(antbar01) Check heuristics here to figure out if we should use an ACL IFunction
- _arm_gemm.configure(a, b, d, alpha, beta, pretranspose_hint, _memory_group);
+ INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ unsigned int num_threads = NEScheduler::get().num_threads();
+
+ arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+
+ //Try to create an ACL function:
+ if(!create_function(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint))
+ {
+ //Fallback onto arm_gemm function if ACL doesn't support this method.
+ _arm_gemm.configure(a, b, d, args, _memory_group);
+ }
}
template <typename TypeInput, typename TypeOutput>
@@ -75,10 +128,8 @@ void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::run()
}
#ifndef __aarch64__
-namespace arm_compute
-{
template <>
-void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
// arm_gemm::gemm for 8bit only exists for aarch64
ARM_COMPUTE_UNUSED(a);
@@ -87,11 +138,11 @@ void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITenso
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(beta);
ARM_COMPUTE_UNUSED(pretranspose_hint);
- ARM_COMPUTE_UNUSED(memory_group);
+ ARM_COMPUTE_ERROR("Not supported for this architecture");
}
template <>
-void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<int8_t, int32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
{
// arm_gemm::gemm for 8bit only exists for aarch64
ARM_COMPUTE_UNUSED(a);
@@ -100,23 +151,37 @@ void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(beta);
ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_ERROR("Not supported for this architecture");
+}
+
+template <>
+void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<uint32_t> &args, MemoryGroup &memory_group)
+{
+ // arm_gemm::gemm for 8bit only exists for aarch64
+ ARM_COMPUTE_UNUSED(a);
+ ARM_COMPUTE_UNUSED(b);
+ ARM_COMPUTE_UNUSED(d);
+ ARM_COMPUTE_UNUSED(args);
ARM_COMPUTE_UNUSED(memory_group);
+ ARM_COMPUTE_ERROR("Not supported for this architecture");
}
-} //namespace arm_compute
+template <>
+void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<int32_t> &args, MemoryGroup &memory_group)
+{
+ // arm_gemm::gemm for 8bit only exists for aarch64
+ ARM_COMPUTE_UNUSED(a);
+ ARM_COMPUTE_UNUSED(b);
+ ARM_COMPUTE_UNUSED(d);
+ ARM_COMPUTE_UNUSED(args);
+ ARM_COMPUTE_UNUSED(memory_group);
+ ARM_COMPUTE_ERROR("Not supported for this architecture");
+}
#endif // aarch64
template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
{
- const CPUInfo &ci = NEScheduler::get().cpu_info();
- const int M = d->info()->tensor_shape().y();
- const int N = d->info()->tensor_shape().x();
- const int K = a->info()->tensor_shape().x();
- const int batches = d->info()->tensor_shape().total_size_upper(2);
- const int multis = b->info()->tensor_shape().z();
- unsigned int num_threads = NEScheduler::get().num_threads();
-
- _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
if(_gemm_kernel_asm == nullptr)
{
//configuration not supported: Leave function unconfigured:
@@ -139,11 +204,10 @@ void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const IT
//if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
//the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
{
- const unsigned int window_size = _gemm_kernel_asm->get_window_size();
- if(window_size < num_threads)
+ const int window_size = _gemm_kernel_asm->get_window_size();
+ if(window_size < args._maxthreads)
{
- num_threads = window_size;
- _gemm_kernel_asm->set_nthreads(num_threads);
+ _gemm_kernel_asm->set_nthreads(window_size);
}
}
@@ -248,8 +312,6 @@ void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::run()
NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
-namespace arm_compute
-{
template class NEGEMMAssemblyDispatch<float, float>;
template class NEGEMMAssemblyDispatch<uint8_t, uint32_t>;
template class NEGEMMAssemblyDispatch<int8_t, int32_t>;