Diffstat (limited to 'src/cpu')
-rw-r--r--  src/cpu/kernels/assembly/gemm_common.hpp               | 24
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 50
2 files changed, 68 insertions(+), 6 deletions(-)
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index ece9ca5802..834cd1061e 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021,2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -113,9 +113,17 @@ public:
{
return 0;
}
+ /* Amount of work for the threaded cases */
+ virtual size_t get_B_pretranspose_window_size() const
+ {
+ return 1;
+ }
/* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
/* The "real" version of this depends on the templated operand type (see below). */
virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
+ /* Threaded version with window start/end parameters */
+ virtual void pretranspose_B_array_part_generic(void *, const void *, const int, const int, const size_t, const size_t) = 0;
+
/* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
virtual void set_pretransposed_B_data(void *)
{
@@ -225,6 +233,20 @@ public:
pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
}
+ /* Threaded versions of the above.
+ * The fallback/backwards-compatible version of the threaded interface exposes a window size of 1 and
+ * simply calls the non-threaded functions to do the work. This is valid because, with a window size of 1,
+ * the only legal values for start and end are 0 and 1 respectively. */
+ virtual void pretranspose_B_array_part(void *out, const To *in, const int row_stride, const int multi_stride, size_t, size_t)
+ {
+ pretranspose_B_array(out, in, row_stride, multi_stride);
+ };
+
+ void pretranspose_B_array_part_generic(void *out, const void *in, const int row_stride, const int multi_stride, size_t start, size_t end) override
+ {
+ pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, start, end);
+ }
+
/*** Indirect interface ***/
virtual void set_indirect_parameters(size_t, const To *const *const *)
{
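
Note on the interface above: the base-class defaults keep existing kernels working unchanged (a window of size 1, with the part call delegating to the full routine), while a kernel that can share the pretranspose work overrides both new virtuals. A minimal self-contained sketch of that contract follows; all class and member names here are hypothetical illustrations, not part of this patch:

    #include <cstddef>

    // Stand-in for the GemmCommon pretranspose interface (names hypothetical).
    struct PretransposeB
    {
        virtual ~PretransposeB() = default;

        // Fallback: a single unit of work, matching the base-class default above.
        virtual std::size_t get_B_pretranspose_window_size() const { return 1; }

        virtual void pretranspose_all() = 0;

        // With a window size of 1 the only legal non-empty range is [0, 1),
        // so delegating to the full, non-threaded routine is always correct.
        virtual void pretranspose_part(std::size_t /*start*/, std::size_t /*end*/)
        {
            pretranspose_all();
        }
    };

    // A kernel that can really share the work advertises a larger window and
    // honours [start, end): here, one unit of work per panel of B.
    struct ThreadedKernel : PretransposeB
    {
        std::size_t num_panels = 8; // hypothetical workload size

        std::size_t get_B_pretranspose_window_size() const override { return num_panels; }

        void pretranspose_all() override { pretranspose_part(0, num_panels); }

        void pretranspose_part(std::size_t start, std::size_t end) override
        {
            for (std::size_t p = start; p < end; ++p)
            {
                // transpose panel p of B (omitted)
            }
        }
    };

    int main()
    {
        ThreadedKernel k;
        const std::size_t w = k.get_B_pretranspose_window_size();
        // A caller splits [0, w) among threads; each thread handles one sub-range:
        k.pretranspose_part(0, w / 2);
        k.pretranspose_part(w / 2, w);
        return 0;
    }
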
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 9af98be41d..9c85631406 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -38,6 +38,46 @@ namespace arm_compute
{
namespace cpu
{
+namespace
+{
+/** Run pretranspose_B_array in parallel (1D static scheduling)
+ *
+ * @tparam TypeInput  Data type of the input operand
+ * @tparam TypeOutput Data type of the output operand
+ *
+ * @param[in]  gemm_asm         GemmCommon kernel to run
+ * @param[out] dst              Pretransposed B array
+ * @param[in]  src              B array to be pretransposed
+ * @param[in]  src_ld           Stride in y
+ * @param[in]  src_multi_stride Stride in z ("multi")
+ * @param[in]  num_threads      Number of threads to run this method. Must be >= 1
+ */
+template <typename TypeInput, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, ITensor *dst, const TypeInput *src, int src_ld, int src_multi_stride, unsigned int num_threads)
+{
+ ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
+ ARM_COMPUTE_ERROR_ON(num_threads == 0);
+ // The window size is also the total workload size
+ const unsigned int wsize = gemm_asm->get_B_pretranspose_window_size();
+
+ std::vector<IScheduler::Workload> workloads(num_threads);
+ for(unsigned int t = 0; t < num_threads; ++t)
+ {
+ workloads[t] = [ = ](const ThreadInfo & info)
+ {
+ const unsigned int start = (info.thread_id * wsize) / num_threads;
+ const unsigned int end = ((info.thread_id + 1) * wsize) / num_threads;
+
+ if(start < end)
+ {
+ gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end);
+ }
+ };
+ }
+ NEScheduler::get().run_tagged_workloads(workloads, "CpuGemmAssemblyDispatch/pretranspose_B_array");
+}
+} // namespace
+
using namespace arm_compute::experimental;
namespace
@@ -436,7 +476,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
- _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
b->mark_as_unused();
}
@@ -493,9 +533,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-transposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
- in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
+ multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -522,7 +562,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
}
else
{
- _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
}
}
}
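
For reference, the static 1D split used by run_parallel_pretranspose_B_array gives thread t the range [t*wsize/num_threads, (t+1)*wsize/num_threads): the ranges are contiguous, non-overlapping, cover [0, wsize) exactly, differ in size by at most one, and come out empty (skipped by the start < end check) when there are more threads than work. A standalone sketch of that arithmetic with example numbers, assuming nothing beyond the standard library:

    #include <cstdio>

    int main()
    {
        const unsigned int wsize       = 10; // example: window size reported by the kernel
        const unsigned int num_threads = 4;  // example: scheduler thread count

        for (unsigned int t = 0; t < num_threads; ++t)
        {
            // Same split as in run_parallel_pretranspose_B_array above.
            const unsigned int start = (t * wsize) / num_threads;
            const unsigned int end   = ((t + 1) * wsize) / num_threads;
            if (start < end) // skip threads that received no work
            {
                std::printf("thread %u -> [%u, %u)\n", t, start, end);
            }
        }
        // Prints: thread 0 -> [0, 2)   thread 1 -> [2, 5)
        //         thread 2 -> [5, 7)   thread 3 -> [7, 10)
        return 0;
    }
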