Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp  50
1 file changed, 45 insertions(+), 5 deletions(-)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 9af98be41d..9c85631406 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -38,6 +38,46 @@ namespace arm_compute
{
namespace cpu
{
+namespace
+{
+/** Run pretranspose_B_array in parallel (1D static scheduling)
+ *
+ * @tparam TypeInput  Data type of the B (input) matrix
+ * @tparam TypeOutput Data type of the pretransposed (output) matrix
+ *
+ * @param[in]  gemm_asm         GemmCommon kernel to run
+ * @param[out] dst              Pretransposed B array
+ * @param[in]  src              B array to be pretransposed
+ * @param[in]  src_ld           Leading dimension of src (stride in y)
+ * @param[in]  src_multi_stride Stride in z ("multi")
+ * @param[in]  num_threads      Number of threads to run this method. Must be >= 1
+ */
+template <typename TypeInput, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, ITensor *dst, const TypeInput *src, int src_ld, int src_multi_stride, unsigned int num_threads)
+{
+ ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
+ ARM_COMPUTE_ERROR_ON(num_threads == 0);
+ // The window size is also the total workload size
+ const unsigned int wsize = gemm_asm->get_B_pretranspose_window_size();
+
+ std::vector<IScheduler::Workload> workloads(num_threads);
+ for(unsigned int t = 0; t < num_threads; ++t)
+ {
+ workloads[t] = [ = ](const ThreadInfo & info)
+ {
+ const unsigned int start = (info.thread_id * wsize) / num_threads;
+ const unsigned int end = ((info.thread_id + 1) * wsize) / num_threads;
+
+ if(start < end)
+ {
+ gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end);
+ }
+ };
+ }
+ NEScheduler::get().run_tagged_workloads(workloads, "CpuGemmAssemblyDispatch/pretranspose_B_array");
+}
+} // namespace
+
using namespace arm_compute::experimental;
namespace
@@ -436,7 +476,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
- _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
b->mark_as_unused();
}
@@ -493,9 +533,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-transposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
- in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+            ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
+            multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+            in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -522,7 +562,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
}
else
{
- _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
}
}
}
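
For reference, the 1D static partition used by the new run_parallel_pretranspose_B_array helper can be reproduced outside the library. Below is a minimal, self-contained sketch (not ComputeLibrary code): std::thread stands in for NEScheduler's thread pool, and process_part is a hypothetical stand-in for GemmCommon::pretranspose_B_array_part. Each thread t statically owns the window [t * wsize / num_threads, (t + 1) * wsize / num_threads), so integer division spreads any remainder across threads and covers [0, wsize) without gaps or overlap; an empty range is skipped, matching the start < end guard in the patch.

    // Sketch of the 1D static split from the patch; std::thread replaces
    // NEScheduler, and process_part is a hypothetical worker.
    #include <cstdio>
    #include <functional>
    #include <thread>
    #include <vector>

    // Hypothetical stand-in for pretranspose_B_array_part: the real kernel
    // pretransposes the slice of B covered by window indices [start, end).
    void process_part(unsigned int start, unsigned int end)
    {
        std::printf("processing window [%u, %u)\n", start, end);
    }

    int main()
    {
        const unsigned int wsize       = 10; // total workload (window) size
        const unsigned int num_threads = 4;  // must be >= 1

        // One workload per thread, exactly as in the patch: thread t owns
        // [t * wsize / num_threads, (t + 1) * wsize / num_threads).
        std::vector<std::function<void()>> workloads(num_threads);
        for(unsigned int t = 0; t < num_threads; ++t)
        {
            workloads[t] = [=]()
            {
                const unsigned int start = (t * wsize) / num_threads;
                const unsigned int end   = ((t + 1) * wsize) / num_threads;
                if(start < end) // skip threads whose range is empty
                {
                    process_part(start, end);
                }
            };
        }

        // Dispatch; NEScheduler::run_tagged_workloads would instead hand
        // each workload to a pool thread, with ThreadInfo::thread_id
        // playing the role of the captured t above.
        std::vector<std::thread> pool;
        for(auto &w : workloads)
        {
            pool.emplace_back(w);
        }
        for(auto &th : pool)
        {
            th.join();
        }
        return 0;
    }

Running this prints one "processing window" line per non-empty range. Because each range is derived purely from the thread id, no synchronization is needed between workloads, which is what makes the static split a good fit for the scheduler's tagged-workload interface.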