aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels')
-rw-r--r--src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp30
1 files changed, 24 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
index 592ee72820..95ece8cdc8 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthfirst_driver.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -173,12 +173,30 @@ class DepthfirstDriver : public DepthwiseCommon<TInput, TWeight, TOutput>
const auto n_output_channels = args.input_channels * args.channel_multiplier;
- for (unsigned int batch = 0; batch < args.n_batches; batch++)
+ // By default we parallelize over the rows, but if there's only 1 row, we
+ // try to parallelize over batches.
+ auto thread_id_for_rows = thread_id;
+ auto n_threads_for_rows = n_threads;
+ auto thread_id_for_batches = 0;
+ auto n_threads_for_batches = 1;
+ if (args.output_rows == 1) {
+ thread_id_for_rows = 0;
+ n_threads_for_rows = 1;
+ thread_id_for_batches = thread_id;
+ n_threads_for_batches = n_threads;
+ }
+
+ // Progress the pointers for the first batch.
+ input_tensor.base += ld_input_batch*thread_id_for_batches;
+ output_tensor.base += ld_output_batch*thread_id_for_batches;
+ for (unsigned int batch = thread_id_for_batches;
+ batch < args.n_batches;
+ batch += n_threads_for_batches)
{
// Iterate over rows of the output tensor; we stripe over the tiles.
- for (unsigned int start_output_i = thread_id * m_strat->get_output_rows();
+ for (unsigned int start_output_i = thread_id_for_rows * m_strat->get_output_rows();
start_output_i < args.output_rows;
- start_output_i += n_threads * m_strat->get_output_rows())
+ start_output_i += n_threads_for_rows * m_strat->get_output_rows())
{
// Determine what padding (if any) is required on the top/bottom of
// this row of the convolution.
@@ -264,8 +282,8 @@ class DepthfirstDriver : public DepthwiseCommon<TInput, TWeight, TOutput>
}
// Progress the pointers for the next batch.
- input_tensor.base += ld_input_batch;
- output_tensor.base += ld_output_batch;
+ input_tensor.base += ld_input_batch*n_threads_for_batches;
+ output_tensor.base += ld_output_batch*n_threads_for_batches;
}
}