aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
index 8d3741de96..38092adfee 100644
--- a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
+++ b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023 Arm Limited.
+ * Copyright (c) 2019-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -108,7 +108,11 @@ void CpuDepthwiseConv2dAssemblyDispatch::run(ITensorPack &tensors)
prepare(tensors);
- NEScheduler::get().schedule_op(_pImpl->asm_kernel.get(), Window::DimY, _pImpl->asm_kernel->window(), tensors);
+ // Split over rows (z) if there's more than 1, otherwise batches (w). This logic
+ // corresponds to the threading strategy in DepthFirstDriver::execute_internal
+ auto split_dimension = _pImpl->asm_kernel->window().num_iterations(Window::DimZ) != 1 ? Window::DimZ : Window::DimW;
+
+ NEScheduler::get().schedule_op(_pImpl->asm_kernel.get(), split_dimension, _pImpl->asm_kernel->window(), tensors);
}
void CpuDepthwiseConv2dAssemblyDispatch::prepare(ITensorPack &tensors)