about | summary | refs | log | tree | commit | diff
path: root/src/cpu/kernels/assembly/gemm_common.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/assembly/gemm_common.hpp')
-rw-r--r-- src/cpu/kernels/assembly/gemm_common.hpp | 41
1 files changed, 29 insertions, 12 deletions
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 6fe9f13f02..4825814e31 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021,2023 Arm Limited.
+ * Copyright (c) 2017-2021,2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP
+
#pragma once
#include "convolution_parameters.hpp"
@@ -116,6 +120,11 @@ public:
{
return false;
}
+ /* Does pretranspose accept the transposed flag? */
+ virtual bool B_pretranspose_supports_transpose() const
+ {
+ return false;
+ }
/* Total number of bytes of space needed for pretransposed arrays. */
virtual size_t get_B_pretransposed_array_size() const
{
@@ -128,10 +137,10 @@ public:
}
/* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
/* The "real" version of this depends on the templated operand type (see below). */
- virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;
+ virtual void pretranspose_B_array_generic(void *, const void *, const int, const int, bool) = 0;
/* Threaded version with window start/end parameters */
virtual void
- pretranspose_B_array_part_generic(void *, const void *, const int, const int, const size_t, const size_t) = 0;
+ pretranspose_B_array_part_generic(void *, const void *, const int, const int, bool, const size_t, const size_t) = 0;
/* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
virtual void set_pretransposed_B_data(void *)
@@ -251,28 +260,34 @@ public:
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const To *, const int, const int){};
+ virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
- void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override
+ void pretranspose_B_array_generic(
+ void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
{
- pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
+ pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
}
/* Threaded versions of the above.
* The fallback/backwards compatible version of the threaded interface exposes a window size of 1 and
* just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
* legal values for start and end are 0 and 1 respectively. */
- virtual void
- pretranspose_B_array_part(void *out, const To *in, const int row_stride, const int multi_stride, size_t, size_t)
+ virtual void pretranspose_B_array_part(
+ void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
{
- pretranspose_B_array(out, in, row_stride, multi_stride);
+ pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
};
- void pretranspose_B_array_part_generic(
- void *out, const void *in, const int row_stride, const int multi_stride, size_t start, size_t end) override
+ void pretranspose_B_array_part_generic(void *out,
+ const void *in,
+ const int row_stride,
+ const int multi_stride,
+ bool transposed,
+ size_t start,
+ size_t end) override
{
- pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, start, end);
+ pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
}
/*** Indirect interface ***/
@@ -287,3 +302,5 @@ public:
};
} // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_GEMM_COMMON_HPP