aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/assembly/gemm_common.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/assembly/gemm_common.hpp')
-rw-r--r--src/cpu/kernels/assembly/gemm_common.hpp18
1 files changed, 9 insertions, 9 deletions
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index f693021fcb..45d1e43274 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -189,7 +189,7 @@ public:
* 'set_arrays' to capture the provided arguments in protected class
* members, as essentially any implementation will need these.
*/
-template <typename To, typename Tw, typename Tr>
+template <typename To, typename Tr>
class GemmCommon : public IGemmCommon
{
protected:
@@ -197,7 +197,7 @@ protected:
int _lda = 0;
int _A_batch_stride = 0;
int _A_multi_stride = 0;
- const Tw *_Bptr = nullptr;
+ const To *_Bptr = nullptr;
int _ldb = 0;
int _B_multi_stride = 0;
Tr *_Cptr = nullptr;
@@ -214,7 +214,7 @@ public:
const int lda,
const int A_batch_stride,
const int A_multi_stride,
- const Tw *B,
+ const To *B,
const int ldb,
/* batches share B */ const int B_multi_stride,
Tr *C,
@@ -254,7 +254,7 @@ public:
const void *bias,
/* no row or batch stride needed */ const int bias_multi_stride) override
{
- set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb,
+ set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb,
B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
static_cast<const Tr *>(bias), bias_multi_stride);
}
@@ -262,17 +262,17 @@ public:
/*** "Pretransposed" interface ***/
/* Compute col sums over all columns */
- virtual void requantize_bias(void *, const Tw *, const int, const int){};
+ virtual void requantize_bias(void *, const To *, const int, const int){};
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){};
+ virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
void pretranspose_B_array_generic(
void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
{
- pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed);
+ pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
}
/* Threaded versions of the above.
@@ -280,7 +280,7 @@ public:
* just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
* legal values for start and end are 0 and 1 respectively. */
virtual void pretranspose_B_array_part(
- void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
+ void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
{
pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
};
@@ -293,7 +293,7 @@ public:
size_t start,
size_t end) override
{
- pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end);
+ pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
}
/*** Indirect interface ***/