aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-10-14 19:03:09 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-10-23 12:08:12 +0000
commit48b3ef89de5f21a0169d8416e3d54081f82c7bf8 (patch)
treef857d733ccf446c704823dc7ac796a96eb55095e /arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
parent1dce3101ef8d77c8cf0af7dfd4af6595a0136b91 (diff)
downloadComputeLibrary-48b3ef89de5f21a0169d8416e3d54081f82c7bf8.tar.gz
COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels
Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/2141 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/gemm_common.hpp23
1 files changed, 17 insertions, 6 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
index 1ae503cddb..d17fd5fe97 100644
--- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
@@ -48,7 +48,8 @@ public:
*/
virtual void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
- void *C, const int ldc, const int C_batch_stride, const int C_multi_stride) = 0;
+ void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0;
/* For threading, we divide the work into some number of units and work
* out internally what unit corresponds to what work. This returns the
@@ -97,7 +98,11 @@ public:
/*** "Quantized bias" interface (optional) ***/
/* Set the bias vector for quantized GEMMs */
- virtual void set_quantized_bias(const int32_t *bias) { UNUSED(bias); }
+ virtual void set_quantized_bias(const int32_t *bias, size_t bias_multi_stride)
+ {
+ UNUSED(bias);
+ UNUSED(bias_multi_stride);
+ }
// Destructor
virtual ~IGemmCommon() { }
@@ -125,13 +130,16 @@ protected:
int _ldc=0;
int _C_batch_stride=0;
int _C_multi_stride=0;
+ const Tr *_bias=nullptr;
+ int _bias_multi_stride=0;
public:
/* Pass in the pointers to the arrays to be operated on and their
* strides (templated version with appropriate types). */
virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
- Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) {
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride) {
_Aptr = A;
_lda = lda;
_A_batch_stride = A_batch_stride;
@@ -143,15 +151,19 @@ public:
_ldc = ldc;
_C_batch_stride = C_batch_stride;
_C_multi_stride = C_multi_stride;
+ _bias = bias;
+ _bias_multi_stride = bias_multi_stride;
}
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
- void *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override {
+ void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override {
set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride,
static_cast<const To *>(B), ldb, B_multi_stride,
- static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride);
+ static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
+ static_cast<const Tr *>(bias), bias_multi_stride);
}
/*** "Pretransposed" interface ***/
@@ -164,7 +176,6 @@ public:
void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override {
pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
}
-
};
} // namespace arm_gemm