diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_conv')
-rw-r--r-- | src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp index 592ee72820..95ece8cdc8 100644 --- a/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp +++ b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -173,12 +173,30 @@ class DepthfirstDriver : public DepthwiseCommon<TInput, TWeight, TOutput> const auto n_output_channels = args.input_channels * args.channel_multiplier; - for (unsigned int batch = 0; batch < args.n_batches; batch++) + // By default we parallelize over the rows, but if there's only 1 row, we + // try to parallelize over batches + auto thread_id_for_rows = thread_id; + auto n_threads_for_rows = n_threads; + auto thread_id_for_batches = 0; + auto n_threads_for_batches = 1; + if (args.output_rows == 1) { + thread_id_for_rows = 0; + n_threads_for_rows = 1; + thread_id_for_batches = thread_id; + n_threads_for_batches = n_threads; + } + + // Progress the pointers for the first batch. + input_tensor.base += ld_input_batch*thread_id_for_batches; + output_tensor.base += ld_output_batch*thread_id_for_batches; + for (unsigned int batch = thread_id_for_batches; + batch < args.n_batches; + batch += n_threads_for_batches) { // Iterate over rows of the output tensor; we stripe over the tiles. - for (unsigned int start_output_i = thread_id * m_strat->get_output_rows(); + for (unsigned int start_output_i = thread_id_for_rows * m_strat->get_output_rows(); start_output_i < args.output_rows; - start_output_i += n_threads * m_strat->get_output_rows()) + start_output_i += n_threads_for_rows * m_strat->get_output_rows()) { // Determine what (if any padding) is required on the top/bottom of // this row of the convolution. 
@@ -264,8 +282,8 @@ class DepthfirstDriver : public DepthwiseCommon<TInput, TWeight, TOutput> } // Progress the pointers for the next batch. - input_tensor.base += ld_input_batch; - output_tensor.base += ld_output_batch; + input_tensor.base += ld_input_batch*n_threads_for_batches; + output_tensor.base += ld_output_batch*n_threads_for_batches; } } |