diff options
Diffstat (limited to 'src/cpu/kernels/assembly/gemm_common.hpp')
-rw-r--r-- | src/cpu/kernels/assembly/gemm_common.hpp | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp index ece9ca5802..834cd1061e 100644 --- a/src/cpu/kernels/assembly/gemm_common.hpp +++ b/src/cpu/kernels/assembly/gemm_common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021,2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -113,9 +113,17 @@ public: { return 0; } + /* Amount of work for the threaded cases */ + virtual size_t get_B_pretranspose_window_size() const + { + return 1; + } /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */ /* The "real" version of this depends on the templated operand type (see below). */ virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0; + /* Threaded version with window start/end parameters */ + virtual void pretranspose_B_array_part_generic(void *, const void *, const int, const int, const size_t, const size_t) = 0; + /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ virtual void set_pretransposed_B_data(void *) { @@ -225,6 +233,20 @@ public: pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride); } + /* Threaded versions of the above. + * The fallback/backwards compatible version of the threaded interface exposes a window size of 1 and + * just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only + * legal values for start and end are 0 and 1 respectively. */ + virtual void pretranspose_B_array_part(void *out, const To *in, const int row_stride, const int multi_stride, size_t, size_t) + { + pretranspose_B_array(out, in, row_stride, multi_stride); + }; + + void pretranspose_B_array_part_generic(void *out, const void *in, const int row_stride, const int multi_stride, size_t start, size_t end) override + { + pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, start, end); + } + /*** Indirect interface ***/ virtual void set_indirect_parameters(size_t, const To *const *const *) { |