diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/runtime/CPP/CPPScheduler.cpp | 4 | ||||
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 20 |
2 files changed, 16 insertions, 8 deletions
diff --git a/src/runtime/CPP/CPPScheduler.cpp b/src/runtime/CPP/CPPScheduler.cpp index 5849218536..e684eeee98 100644 --- a/src/runtime/CPP/CPPScheduler.cpp +++ b/src/runtime/CPP/CPPScheduler.cpp @@ -338,9 +338,9 @@ void CPPScheduler::schedule(ICPPKernel *kernel, const Hints &hints) break; case StrategyHint::DYNAMIC: { + const unsigned int granule_threshold = (hints.threshold() <= 0) ? num_threads : static_cast<unsigned int>(hints.threshold()); // Make sure we don't use some windows which are too small as this might create some contention on the ThreadFeeder - const unsigned int max_iterations = static_cast<unsigned int>(_impl->_num_threads) * 3; - num_windows = num_iterations > max_iterations ? max_iterations : num_iterations; + num_windows = num_iterations > granule_threshold ? granule_threshold : num_iterations; break; } default: diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 43e531579a..88e060109a 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -201,6 +201,8 @@ private: IWeightsManager *_weights_manager{ nullptr }; /** Weights transform object */ FallbackTransform<TypeInput, TypeOutput> _weights_transform{}; + /** GEMM kernel description */ + arm_gemm::KernelDescription _kernel_info{}; }; template <typename TypeInput, typename TypeOutput, class OutputStage> @@ -208,12 +210,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c arm_gemm::GemmArgs args, const GEMMInfo &gemm_info, MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os) { - arm_gemm::GemmConfig gemm_cfg; - const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os); - _weights_manager = weights_manager; - if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED) + arm_gemm::GemmConfig gemm_cfg; + _kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os); + _weights_manager = weights_manager; + if(_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED) { - gemm_cfg.filter = gemm_kernel_info.name; + gemm_cfg.filter = _kernel_info.name; args._cfg = &gemm_cfg; } _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os); @@ -387,7 +389,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run() bias, 0); // Schedule assembly kernel - NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); + IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX); + if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED) + { + constexpr int granule_threshold = 200; + scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold); + } + NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint); } template <typename TypeInput, typename TypeOutput> |