aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/assembly/gemm_common.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/assembly/gemm_common.hpp')
-rw-r--r-- src/cpu/kernels/assembly/gemm_common.hpp | 24
1 files changed, 23 insertions, 1 deletions
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index ece9ca5802..834cd1061e 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021,2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -113,9 +113,17 @@ public:
{
return 0;
}
+ /* Amount of work (window size) available for the threaded pretranspose case; this default reports a single unit. */
+ virtual size_t get_B_pretranspose_window_size() const
+ {
+ return 1;
+ }
/* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
/* The "real" version of this depends on the templated operand type (see below). */
virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
+ /* Threaded version of the above - the two additional trailing size_t arguments are the window start and end. */
+ virtual void pretranspose_B_array_part_generic(void *, const void *, const int, const int, const size_t, const size_t) = 0;
+
/* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
virtual void set_pretransposed_B_data(void *)
{
@@ -225,6 +233,20 @@ public:
pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
}
+ /* Threaded versions of the above, taking additional window start and end arguments.
+ * The fallback/backwards compatible version of the threaded interface exposes a window size of 1, ignores
+ * the window arguments and just calls the non-threaded function to do the work. This is valid as with a
+ * window size of 1 the only legal values for start and end are 0 and 1 respectively. */
+ virtual void pretranspose_B_array_part(void *out, const To *in, const int row_stride, const int multi_stride, size_t, size_t)
+ {
+ pretranspose_B_array(out, in, row_stride, multi_stride);
+ }
+
+ void pretranspose_B_array_part_generic(void *out, const void *in, const int row_stride, const int multi_stride, size_t start, size_t end) override
+ { /* Untyped entry point: cast the input pointer to const To * and forward everything to the typed version. */
+ pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, start, end);
+ }
+
/*** Indirect interface ***/
virtual void set_indirect_parameters(size_t, const To *const *const *)
{