author     Gian Marco Iodice <gianmarco.iodice@arm.com>  2018-04-06 10:00:10 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>     2018-11-02 16:49:37 +0000
commit     c9c62c2fa1c80ba7f11b0d0732740460dfa00e74 (patch)
tree       260052aa5c7172e2afc8517ae13adb75504ee62e /src/core/CL/cl_kernels/gemm.cl
parent     3ab6804a0e30be4d8591c8c84ae6a73940d0f2e2 (diff)
download   ComputeLibrary-c9c62c2fa1c80ba7f11b0d0732740460dfa00e74.tar.gz
COMPMID-1056 - Optimizing CLGEMMMatrixMultiplyKernel by refactoring the inner loop
Results reported at: https://confluence.arm.com/display/MLENG/GEMM+FP32+performance%3A+ACL+18.05

Change-Id: I3246c4f19c4d21a7d6a44e4593bc5caffc016f81
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127838
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r--  src/core/CL/cl_kernels/gemm.cl  260
1 file changed, 206 insertions(+), 54 deletions(-)
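
The patch replaces the pointer-bound inner loops of the Bifrost GEMM kernels with integer-indexed loops, unrolled by 4 (or 8 along K in gemm_mm_floating_point_f32_bifrost_1000), and bumps the source pointers immediately after each load so every vload addresses offset 0. A minimal, simplified sketch of the pattern is shown below; the identifiers lhs, rhs, acc, K, LHS_STEP and RHS_STEP are illustrative placeholders, not names from gemm.cl.

    // Before: loop bounded by a precomputed end pointer; unrolled loads use growing offsets.
    __global float *rhs_end = rhs + K * RHS_STEP;
    for(; rhs <= (rhs_end - 4 * RHS_STEP); lhs += 4 * LHS_STEP, rhs += 4 * RHS_STEP)
    {
        float4 a = vload4(0, lhs);
        float4 b = vload4(0, rhs);
        acc = fma(a.s0, b.s0, acc);
        // ...
        a = vload4(0, lhs + 1 * LHS_STEP);
        b = vload4(0, rhs + 1 * RHS_STEP);
        // ... remaining unrolled steps load at offsets 2 and 3
    }

    // After: integer induction variable; pointers advance right after each load,
    // so every vload4 reads at offset 0 and the increments become simple post-bumps.
    int i = 0;
    for(; i <= (int)(K - 4); i += 4)
    {
        float4 a = vload4(0, lhs);
        float4 b = vload4(0, rhs);
        lhs += LHS_STEP;
        rhs += RHS_STEP;
        acc = fma(a.s0, b.s0, acc);
        // ... three more load/FMA groups follow the same pattern
    }
    for(; i < (int)K; ++i)   // left-over iterations
    {
        float4 a = vload4(0, lhs);
        float4 b = vload4(0, rhs);
        lhs += LHS_STEP;
        rhs += RHS_STEP;
        acc = fma(a.s0, b.s0, acc);
    }

As the commit title suggests, the intent is to give the compiler a simple induction variable and zero-offset loads to schedule around; the linked Confluence page reports the measured FP32 GEMM numbers for this change.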
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 4b1672ce7b..381130ea7f 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -342,9 +342,6 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
__global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
__global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
- // Compute end row address for matrix B
- __global float *src_end_addr_b = src_addr_b + COLS_B;
-
src_addr_a += offset_row_a;
src_addr_b += offset_row_b;
@@ -366,12 +363,18 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
float c32 = 0.0f;
float c33 = 0.0f;
- for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH))
+#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
+
+ int i = 0;
+ for(; i <= (int)(COLS_MTX_B - 4); i += 4)
{
// Load values from matrix A (interleaved) and matrix B (transposed)
float4 a0 = vload4(0, src_addr_a);
float4 b0 = vload4(0, src_addr_b);
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
c02 = fma(a0.s0, b0.s2, c02);
@@ -393,8 +396,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
// Load values from matrix A (interleaved) and matrix B (transposed)
- a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
- b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
+ a0 = vload4(0, src_addr_a);
+ b0 = vload4(0, src_addr_b);
+
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
@@ -417,8 +423,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
// Load values from matrix A (interleaved) and matrix B (transposed)
- a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT);
- b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
+ a0 = vload4(0, src_addr_a);
+ b0 = vload4(0, src_addr_b);
+
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
@@ -441,8 +450,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
// Load values from matrix A (interleaved) and matrix B (transposed)
- a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT);
- b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH);
+ a0 = vload4(0, src_addr_a);
+ b0 = vload4(0, src_addr_b);
+
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
@@ -465,12 +477,15 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
c33 = fma(a0.s3, b0.s3, c33);
}
- for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH))
+ for(; i < (int)(COLS_MTX_B); ++i)
{
// Load values from matrix A (interleaved) and matrix B (transposed)
float4 a0 = vload4(0, src_addr_a);
float4 b0 = vload4(0, src_addr_b);
+ src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+ src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
c00 = fma(a0.s0, b0.s0, c00);
c01 = fma(a0.s0, b0.s1, c01);
c02 = fma(a0.s0, b0.s2, c02);
@@ -1130,9 +1145,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
- // Address boundary for matrix A
- int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
// Initialize accumulators
float acc00 = 0.0f;
float acc01 = 0.0f;
@@ -1161,72 +1173,162 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// A and B src indices get incremented at the same time.
- for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
+ int i = 0;
+ for(; i <= ((int)COLS_A - 4); i += 4)
{
- // Load values from matrix A
- float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+ // Load values from matrix A and matrix B
+ float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+ float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+ float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+ float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- // Load values from matrix B
- float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
- float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
+ float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0.s0, b0.s0, acc00);
- acc00 = fma(a0.s1, b1.s0, acc00);
acc01 = fma(a0.s0, b0.s1, acc01);
- acc01 = fma(a0.s1, b1.s1, acc01);
acc02 = fma(a0.s0, b0.s2, acc02);
- acc02 = fma(a0.s1, b1.s2, acc02);
- acc03 = fma(a0.s1, b1.s3, acc03);
acc03 = fma(a0.s0, b0.s3, acc03);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
acc10 = fma(a1.s0, b0.s0, acc10);
acc11 = fma(a1.s0, b0.s1, acc11);
acc12 = fma(a1.s0, b0.s2, acc12);
acc13 = fma(a1.s0, b0.s3, acc13);
- acc10 = fma(a1.s1, b1.s0, acc10);
- acc11 = fma(a1.s1, b1.s1, acc11);
- acc12 = fma(a1.s1, b1.s2, acc12);
- acc13 = fma(a1.s1, b1.s3, acc13);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
acc20 = fma(a2.s0, b0.s0, acc20);
acc21 = fma(a2.s0, b0.s1, acc21);
acc22 = fma(a2.s0, b0.s2, acc22);
acc23 = fma(a2.s0, b0.s3, acc23);
- acc20 = fma(a2.s1, b1.s0, acc20);
- acc21 = fma(a2.s1, b1.s1, acc21);
- acc22 = fma(a2.s1, b1.s2, acc22);
- acc23 = fma(a2.s1, b1.s3, acc23);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
acc30 = fma(a3.s0, b0.s0, acc30);
acc31 = fma(a3.s0, b0.s1, acc31);
acc32 = fma(a3.s0, b0.s2, acc32);
acc33 = fma(a3.s0, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ // Load values from matrix A and matrix B
+ b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+
+ // Multiply and accumulate
+ acc00 = fma(a0.s1, b0.s0, acc00);
+ acc01 = fma(a0.s1, b0.s1, acc01);
+ acc02 = fma(a0.s1, b0.s2, acc02);
+ acc03 = fma(a0.s1, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+ acc10 = fma(a1.s1, b0.s0, acc10);
+ acc11 = fma(a1.s1, b0.s1, acc11);
+ acc12 = fma(a1.s1, b0.s2, acc12);
+ acc13 = fma(a1.s1, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+ acc20 = fma(a2.s1, b0.s0, acc20);
+ acc21 = fma(a2.s1, b0.s1, acc21);
+ acc22 = fma(a2.s1, b0.s2, acc22);
+ acc23 = fma(a2.s1, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc30 = fma(a3.s1, b1.s0, acc30);
- acc31 = fma(a3.s1, b1.s1, acc31);
- acc32 = fma(a3.s1, b1.s2, acc32);
- acc33 = fma(a3.s1, b1.s3, acc33);
+ acc30 = fma(a3.s1, b0.s0, acc30);
+ acc31 = fma(a3.s1, b0.s1, acc31);
+ acc32 = fma(a3.s1, b0.s2, acc32);
+ acc33 = fma(a3.s1, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ // Load values from matrix A and matrix B
+ b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+
+ // Multiply and accumulate
+ acc00 = fma(a0.s2, b0.s0, acc00);
+ acc01 = fma(a0.s2, b0.s1, acc01);
+ acc02 = fma(a0.s2, b0.s2, acc02);
+ acc03 = fma(a0.s2, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+ acc10 = fma(a1.s2, b0.s0, acc10);
+ acc11 = fma(a1.s2, b0.s1, acc11);
+ acc12 = fma(a1.s2, b0.s2, acc12);
+ acc13 = fma(a1.s2, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+ acc20 = fma(a2.s2, b0.s0, acc20);
+ acc21 = fma(a2.s2, b0.s1, acc21);
+ acc22 = fma(a2.s2, b0.s2, acc22);
+ acc23 = fma(a2.s2, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ acc30 = fma(a3.s2, b0.s0, acc30);
+ acc31 = fma(a3.s2, b0.s1, acc31);
+ acc32 = fma(a3.s2, b0.s2, acc32);
+ acc33 = fma(a3.s2, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ // Load values from matrix A and matrix B
+ b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+
+ // Multiply and accumulate
+ acc00 = fma(a0.s3, b0.s0, acc00);
+ acc01 = fma(a0.s3, b0.s1, acc01);
+ acc02 = fma(a0.s3, b0.s2, acc02);
+ acc03 = fma(a0.s3, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+ acc10 = fma(a1.s3, b0.s0, acc10);
+ acc11 = fma(a1.s3, b0.s1, acc11);
+ acc12 = fma(a1.s3, b0.s2, acc12);
+ acc13 = fma(a1.s3, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+ acc20 = fma(a2.s3, b0.s0, acc20);
+ acc21 = fma(a2.s3, b0.s1, acc21);
+ acc22 = fma(a2.s3, b0.s2, acc22);
+ acc23 = fma(a2.s3, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ acc30 = fma(a3.s3, b0.s0, acc30);
+ acc31 = fma(a3.s3, b0.s1, acc31);
+ acc32 = fma(a3.s3, b0.s2, acc32);
+ acc33 = fma(a3.s3, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += 4 * sizeof(float);
}
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
+ for(; i < (int)COLS_A; ++i)
{
// Load values from matrix A
- float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1238,6 +1340,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// Load values from matrix B
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0, b0.s0, acc00);
@@ -1262,6 +1365,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc32 = fma(a3, b0.s2, acc32);
acc33 = fma(a3, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += sizeof(float);
}
// Compute destination address
@@ -1375,9 +1480,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
- // Address boundary for the matrix A
- int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
// Initialize accumulators
float acc00 = 0.0f;
float acc01 = 0.0f;
@@ -1396,67 +1498,114 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// A and B src indices get incremented at the same time.
- for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
+ int i = 0;
+ for(; i <= ((int)COLS_A - 8); i += 8)
{
// Load values from matrix A
- float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+ float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
// Load values from matrix B
- float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
- float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
- float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
- float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));
+ float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
+ float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0.s0, b0.s0, acc00);
acc00 = fma(a0.s1, b1.s0, acc00);
acc00 = fma(a0.s2, b2.s0, acc00);
acc00 = fma(a0.s3, b3.s0, acc00);
+ acc00 = fma(a0.s4, b4.s0, acc00);
+ acc00 = fma(a0.s5, b5.s0, acc00);
+ acc00 = fma(a0.s6, b6.s0, acc00);
+ acc00 = fma(a0.s7, b7.s0, acc00);
acc01 = fma(a0.s0, b0.s1, acc01);
acc01 = fma(a0.s1, b1.s1, acc01);
acc01 = fma(a0.s2, b2.s1, acc01);
acc01 = fma(a0.s3, b3.s1, acc01);
+ acc01 = fma(a0.s4, b4.s1, acc01);
+ acc01 = fma(a0.s5, b5.s1, acc01);
+ acc01 = fma(a0.s6, b6.s1, acc01);
+ acc01 = fma(a0.s7, b7.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
acc10 = fma(a0.s0, b0.s0, acc10);
acc10 = fma(a0.s1, b1.s0, acc10);
acc10 = fma(a0.s2, b2.s0, acc10);
acc10 = fma(a0.s3, b3.s0, acc10);
+ acc10 = fma(a0.s4, b4.s0, acc10);
+ acc10 = fma(a0.s5, b5.s0, acc10);
+ acc10 = fma(a0.s6, b6.s0, acc10);
+ acc10 = fma(a0.s7, b7.s0, acc10);
acc11 = fma(a0.s0, b0.s1, acc11);
acc11 = fma(a0.s1, b1.s1, acc11);
acc11 = fma(a0.s2, b2.s1, acc11);
acc11 = fma(a0.s3, b3.s1, acc11);
+ acc11 = fma(a0.s4, b4.s1, acc11);
+ acc11 = fma(a0.s5, b5.s1, acc11);
+ acc11 = fma(a0.s6, b6.s1, acc11);
+ acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
acc20 = fma(a0.s0, b0.s0, acc20);
acc20 = fma(a0.s1, b1.s0, acc20);
acc20 = fma(a0.s2, b2.s0, acc20);
acc20 = fma(a0.s3, b3.s0, acc20);
+ acc20 = fma(a0.s4, b4.s0, acc20);
+ acc20 = fma(a0.s5, b5.s0, acc20);
+ acc20 = fma(a0.s6, b6.s0, acc20);
+ acc20 = fma(a0.s7, b7.s0, acc20);
acc21 = fma(a0.s0, b0.s1, acc21);
acc21 = fma(a0.s1, b1.s1, acc21);
acc21 = fma(a0.s2, b2.s1, acc21);
acc21 = fma(a0.s3, b3.s1, acc21);
+ acc21 = fma(a0.s4, b4.s1, acc21);
+ acc21 = fma(a0.s5, b5.s1, acc21);
+ acc21 = fma(a0.s6, b6.s1, acc21);
+ acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
acc30 = fma(a0.s0, b0.s0, acc30);
acc30 = fma(a0.s1, b1.s0, acc30);
acc30 = fma(a0.s2, b2.s0, acc30);
acc30 = fma(a0.s3, b3.s0, acc30);
+ acc30 = fma(a0.s4, b4.s0, acc30);
+ acc30 = fma(a0.s5, b5.s0, acc30);
+ acc30 = fma(a0.s6, b6.s0, acc30);
+ acc30 = fma(a0.s7, b7.s0, acc30);
acc31 = fma(a0.s0, b0.s1, acc31);
acc31 = fma(a0.s1, b1.s1, acc31);
acc31 = fma(a0.s2, b2.s1, acc31);
acc31 = fma(a0.s3, b3.s1, acc31);
+ acc31 = fma(a0.s4, b4.s1, acc31);
+ acc31 = fma(a0.s5, b5.s1, acc31);
+ acc31 = fma(a0.s6, b6.s1, acc31);
+ acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += sizeof(float) * 8;
}
// float size increment
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(4, src1_stride_y))
+ for(; i < (int)COLS_A; ++i)
{
// Load values from matrix A
float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1471,6 +1620,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+ src_addr.s1 += src1_stride_y;
// Multiply and accumulate
acc00 = fma(a0, b0.s0, acc00);
@@ -1487,6 +1637,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc30 = fma(a3, b0.s0, acc30);
acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += sizeof(float);
}
// Compute destination address