diff options
Diffstat (limited to 'src/runtime/NEON/functions')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index a3080e7f29..24bd7d7a8c 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -280,8 +280,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001 { - const int window_size = _gemm_kernel_asm->get_window_size(); - if(window_size < args._maxthreads) + const unsigned int window_size = get_total_window_size(*_gemm_kernel_asm); + if(window_size < static_cast<unsigned int>(args._maxthreads)) { _gemm_kernel_asm->set_nthreads(window_size); } @@ -404,7 +404,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run() if(_workspace.buffer() != nullptr) { _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer())); - const unsigned int window_size = _gemm_kernel_asm->get_window_size(); + const unsigned int window_size = get_total_window_size(*_gemm_kernel_asm); unsigned int num_threads = NEScheduler::get().num_threads(); if(window_size < num_threads) { @@ -427,14 +427,21 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run() in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d, bias, 0); - // Schedule assembly kernel IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && _d->info()->data_type() == DataType::F32) { const int granule_threshold = 200; scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); + + } + else if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && _d->info()->data_type() == DataType::F32) + { + //GEMM_INTERLEAVED_2D supports 2D parallelism, 
IScheduler::split_dimensions_all signals to parallelise over all window dimensions + const int granule_threshold = 200; + scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold); } + NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint); } |