aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp15
1 files changed, 11 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index a3080e7f29..24bd7d7a8c 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -280,8 +280,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
//if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
//the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
{
- const int window_size = _gemm_kernel_asm->get_window_size();
- if(window_size < args._maxthreads)
+ const unsigned int window_size = get_total_window_size(*_gemm_kernel_asm);
+ if(window_size < static_cast<unsigned int>(args._maxthreads))
{
_gemm_kernel_asm->set_nthreads(window_size);
}
@@ -404,7 +404,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
if(_workspace.buffer() != nullptr)
{
_gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
- const unsigned int window_size = _gemm_kernel_asm->get_window_size();
+ const unsigned int window_size = get_total_window_size(*_gemm_kernel_asm);
unsigned int num_threads = NEScheduler::get().num_threads();
if(window_size < num_threads)
{
@@ -427,14 +427,21 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
in1_ptr, ldb, multi_stride_b,
out_ptr, ldd, batch_stride_d, multi_stride_d,
bias, 0);
-
// Schedule assembly kernel
IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX);
if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && _d->info()->data_type() == DataType::F32)
{
const int granule_threshold = 200;
scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold);
+
+ }
+ else if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && _d->info()->data_type() == DataType::F32)
+ {
+ //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions
+ const int granule_threshold = 200;
+ scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
}
+
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}