diff options
Diffstat (limited to 'src/cpu/kernels/assembly/gemm_common.hpp')
-rw-r--r-- | src/cpu/kernels/assembly/gemm_common.hpp | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp index f693021fcb..45d1e43274 100644 --- a/src/cpu/kernels/assembly/gemm_common.hpp +++ b/src/cpu/kernels/assembly/gemm_common.hpp @@ -189,7 +189,7 @@ public: * 'set_arrays' to capture the provided arguments in protected class * members, as essentially any implementation will need these. */ -template <typename To, typename Tw, typename Tr> +template <typename To, typename Tr> class GemmCommon : public IGemmCommon { protected: @@ -197,7 +197,7 @@ protected: int _lda = 0; int _A_batch_stride = 0; int _A_multi_stride = 0; - const Tw *_Bptr = nullptr; + const To *_Bptr = nullptr; int _ldb = 0; int _B_multi_stride = 0; Tr *_Cptr = nullptr; @@ -214,7 +214,7 @@ public: const int lda, const int A_batch_stride, const int A_multi_stride, - const Tw *B, + const To *B, const int ldb, /* batches share B */ const int B_multi_stride, Tr *C, @@ -254,7 +254,7 @@ public: const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override { - set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb, + set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb, B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride, static_cast<const Tr *>(bias), bias_multi_stride); } @@ -262,17 +262,17 @@ public: /*** "Pretransposed" interface ***/ /* Compute col sums over all columns */ - virtual void requantize_bias(void *, const Tw *, const int, const int){}; + virtual void requantize_bias(void *, const To *, const int, const int){}; /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */ - virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){}; + virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){}; /* Implementation of the void * overload which casts its arguments to the appropriate type. */ void pretranspose_B_array_generic( void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override { - pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed); + pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed); } /* Threaded versions of the above. @@ -280,7 +280,7 @@ public: * just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only * legal values for start and end are 0 and 1 respectively. */ virtual void pretranspose_B_array_part( - void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t) + void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t) { pretranspose_B_array(out, in, row_stride, multi_stride, transposed); }; @@ -293,7 +293,7 @@ public: size_t start, size_t end) override { - pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end); + pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end); } /*** Indirect interface ***/ |