From c9c62c2fa1c80ba7f11b0d0732740460dfa00e74 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Fri, 6 Apr 2018 10:00:10 +0100
Subject: COMPMID-1056 - Optimizing CLGEMMMatrixMultiplyKernel refactoring the inner loop

Results reported at:
https://confluence.arm.com/display/MLENG/GEMM+FP32+performance%3A+ACL+18.05

Change-Id: I3246c4f19c4d21a7d6a44e4593bc5caffc016f81
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127838
Tested-by: Jenkins
Reviewed-by: Georgios Pinitas
Reviewed-by: Anthony Barbier
---
 src/core/CL/cl_kernels/gemm.cl | 260 ++++++++++++++++++++++++++++++++---------
 1 file changed, 206 insertions(+), 54 deletions(-)

(limited to 'src/core/CL/cl_kernels/gemm.cl')

diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 4b1672ce7b..381130ea7f 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -342,9 +342,6 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
     __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
     __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
 
-    // Compute end row address for matrix B
-    __global float *src_end_addr_b = src_addr_b + COLS_B;
-
     src_addr_a += offset_row_a;
     src_addr_b += offset_row_b;
 
@@ -366,12 +363,18 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
     float c32 = 0.0f;
     float c33 = 0.0f;
 
-    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH))
+#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
+
+    int i = 0;
+    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
     {
         // Load values from matrix A (interleaved) and matrix B (transposed)
         float4 a0 = vload4(0, src_addr_a);
         float4 b0 = vload4(0, src_addr_b);
 
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
         c02 = fma(a0.s0, b0.s2, c02);
@@ -393,8 +396,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
         c33 = fma(a0.s3, b0.s3, c33);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
-        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
-        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
+        a0 = vload4(0, src_addr_a);
+        b0 = vload4(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
@@ -417,8 +423,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
         c33 = fma(a0.s3, b0.s3, c33);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
-        a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT);
-        b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
+        a0 = vload4(0, src_addr_a);
+        b0 = vload4(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
@@ -441,8 +450,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
         c33 = fma(a0.s3, b0.s3, c33);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
-        a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT);
-        b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH);
+        a0 = vload4(0, src_addr_a);
+        b0 = vload4(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
@@ -465,12 +477,15 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
         c33 = fma(a0.s3, b0.s3, c33);
     }
 
-    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH))
+    for(; i < (int)(COLS_MTX_B); ++i)
     {
         // Load values from matrix A (interleaved) and matrix B (transposed)
         float4 a0 = vload4(0, src_addr_a);
         float4 b0 = vload4(0, src_addr_b);
 
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
         c02 = fma(a0.s0, b0.s2, c02);
@@ -1130,9 +1145,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
     src_addr.s1 += get_global_id(2) * src1_stride_z;
 #endif // defined(MATRIX_B_DEPTH)
 
-    // Address boundary for matrix A
-    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
     // Initialize accumulators
     float acc00 = 0.0f;
     float acc01 = 0.0f;
@@ -1161,72 +1173,162 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
     // A and B src indices get incremented at the same time.
-    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
+    int i = 0;
+    for(; i <= ((int)COLS_A - 4); i += 4)
     {
-        // Load values from matrix A
-        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+        // Load values from matrix A and matrix B
+        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-        float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+        float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-        float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+        float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+        float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
-        // Load values from matrix B
-        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
-        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
+        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0.s0, b0.s0, acc00);
-        acc00 = fma(a0.s1, b1.s0, acc00);
         acc01 = fma(a0.s0, b0.s1, acc01);
-        acc01 = fma(a0.s1, b1.s1, acc01);
         acc02 = fma(a0.s0, b0.s2, acc02);
-        acc02 = fma(a0.s1, b1.s2, acc02);
-        acc03 = fma(a0.s1, b1.s3, acc03);
         acc03 = fma(a0.s0, b0.s3, acc03);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
         acc10 = fma(a1.s0, b0.s0, acc10);
         acc11 = fma(a1.s0, b0.s1, acc11);
         acc12 = fma(a1.s0, b0.s2, acc12);
         acc13 = fma(a1.s0, b0.s3, acc13);
-        acc10 = fma(a1.s1, b1.s0, acc10);
-        acc11 = fma(a1.s1, b1.s1, acc11);
-        acc12 = fma(a1.s1, b1.s2, acc12);
-        acc13 = fma(a1.s1, b1.s3, acc13);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
         acc20 = fma(a2.s0, b0.s0, acc20);
         acc21 = fma(a2.s0, b0.s1, acc21);
         acc22 = fma(a2.s0, b0.s2, acc22);
         acc23 = fma(a2.s0, b0.s3, acc23);
-        acc20 = fma(a2.s1, b1.s0, acc20);
-        acc21 = fma(a2.s1, b1.s1, acc21);
-        acc22 = fma(a2.s1, b1.s2, acc22);
-        acc23 = fma(a2.s1, b1.s3, acc23);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
         acc30 = fma(a3.s0, b0.s0, acc30);
         acc31 = fma(a3.s0, b0.s1, acc31);
         acc32 = fma(a3.s0, b0.s2, acc32);
         acc33 = fma(a3.s0, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        // Load values from matrix A and matrix B
+        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+
+        // Multiply and accumulate
+        acc00 = fma(a0.s1, b0.s0, acc00);
+        acc01 = fma(a0.s1, b0.s1, acc01);
+        acc02 = fma(a0.s1, b0.s2, acc02);
+        acc03 = fma(a0.s1, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+        acc10 = fma(a1.s1, b0.s0, acc10);
+        acc11 = fma(a1.s1, b0.s1, acc11);
+        acc12 = fma(a1.s1, b0.s2, acc12);
+        acc13 = fma(a1.s1, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+        acc20 = fma(a2.s1, b0.s0, acc20);
+        acc21 = fma(a2.s1, b0.s1, acc21);
+        acc22 = fma(a2.s1, b0.s2, acc22);
+        acc23 = fma(a2.s1, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        acc30 = fma(a3.s1, b1.s0, acc30);
-        acc31 = fma(a3.s1, b1.s1, acc31);
-        acc32 = fma(a3.s1, b1.s2, acc32);
-        acc33 = fma(a3.s1, b1.s3, acc33);
+        acc30 = fma(a3.s1, b0.s0, acc30);
+        acc31 = fma(a3.s1, b0.s1, acc31);
+        acc32 = fma(a3.s1, b0.s2, acc32);
+        acc33 = fma(a3.s1, b0.s3, acc33);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        // Load values from matrix A and matrix B
+        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+
+        // Multiply and accumulate
+        acc00 = fma(a0.s2, b0.s0, acc00);
+        acc01 = fma(a0.s2, b0.s1, acc01);
+        acc02 = fma(a0.s2, b0.s2, acc02);
+        acc03 = fma(a0.s2, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+        acc10 = fma(a1.s2, b0.s0, acc10);
+        acc11 = fma(a1.s2, b0.s1, acc11);
+        acc12 = fma(a1.s2, b0.s2, acc12);
+        acc13 = fma(a1.s2, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+        acc20 = fma(a2.s2, b0.s0, acc20);
+        acc21 = fma(a2.s2, b0.s1, acc21);
+        acc22 = fma(a2.s2, b0.s2, acc22);
+        acc23 = fma(a2.s2, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        acc30 = fma(a3.s2, b0.s0, acc30);
+        acc31 = fma(a3.s2, b0.s1, acc31);
+        acc32 = fma(a3.s2, b0.s2, acc32);
+        acc33 = fma(a3.s2, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        // Load values from matrix A and matrix B
+        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+
+        // Multiply and accumulate
+        acc00 = fma(a0.s3, b0.s0, acc00);
+        acc01 = fma(a0.s3, b0.s1, acc01);
+        acc02 = fma(a0.s3, b0.s2, acc02);
+        acc03 = fma(a0.s3, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+        acc10 = fma(a1.s3, b0.s0, acc10);
+        acc11 = fma(a1.s3, b0.s1, acc11);
+        acc12 = fma(a1.s3, b0.s2, acc12);
+        acc13 = fma(a1.s3, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+        acc20 = fma(a2.s3, b0.s0, acc20);
+        acc21 = fma(a2.s3, b0.s1, acc21);
+        acc22 = fma(a2.s3, b0.s2, acc22);
+        acc23 = fma(a2.s3, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        acc30 = fma(a3.s3, b0.s0, acc30);
+        acc31 = fma(a3.s3, b0.s1, acc31);
+        acc32 = fma(a3.s3, b0.s2, acc32);
+        acc33 = fma(a3.s3, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += 4 * sizeof(float);
     }
 
-    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
+    for(; i < (int)COLS_A; ++i)
     {
         // Load values from matrix A
-        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+        float a0 = *((__global float *)(src0_ptr + src_addr.s0));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
         float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1238,6 +1340,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         // Load values from matrix B
         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0, b0.s0, acc00);
@@ -1262,6 +1365,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
         acc32 = fma(a3, b0.s2, acc32);
         acc33 = fma(a3, b0.s3, acc33);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += sizeof(float);
     }
 
     // Compute destination address
@@ -1375,9 +1480,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
     src_addr.s1 += get_global_id(2) * src1_stride_z;
 #endif // defined(MATRIX_B_DEPTH)
 
-    // Address boundary for the matrix A
-    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
     // Initialize accumulators
     float acc00 = 0.0f;
     float acc01 = 0.0f;
@@ -1396,67 +1498,114 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
     // A and B src indices get incremented at the same time.
-    for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
+    int i = 0;
+    for(; i <= ((int)COLS_A - 8); i += 8)
    {
         // Load values from matrix A
-        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
 
         // Load values from matrix B
-        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
-        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
-        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
-        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));
+        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0.s0, b0.s0, acc00);
         acc00 = fma(a0.s1, b1.s0, acc00);
         acc00 = fma(a0.s2, b2.s0, acc00);
         acc00 = fma(a0.s3, b3.s0, acc00);
+        acc00 = fma(a0.s4, b4.s0, acc00);
+        acc00 = fma(a0.s5, b5.s0, acc00);
+        acc00 = fma(a0.s6, b6.s0, acc00);
+        acc00 = fma(a0.s7, b7.s0, acc00);
 
         acc01 = fma(a0.s0, b0.s1, acc01);
         acc01 = fma(a0.s1, b1.s1, acc01);
         acc01 = fma(a0.s2, b2.s1, acc01);
         acc01 = fma(a0.s3, b3.s1, acc01);
+        acc01 = fma(a0.s4, b4.s1, acc01);
+        acc01 = fma(a0.s5, b5.s1, acc01);
+        acc01 = fma(a0.s6, b6.s1, acc01);
+        acc01 = fma(a0.s7, b7.s1, acc01);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-        a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
         acc10 = fma(a0.s0, b0.s0, acc10);
         acc10 = fma(a0.s1, b1.s0, acc10);
         acc10 = fma(a0.s2, b2.s0, acc10);
         acc10 = fma(a0.s3, b3.s0, acc10);
+        acc10 = fma(a0.s4, b4.s0, acc10);
+        acc10 = fma(a0.s5, b5.s0, acc10);
+        acc10 = fma(a0.s6, b6.s0, acc10);
+        acc10 = fma(a0.s7, b7.s0, acc10);
 
         acc11 = fma(a0.s0, b0.s1, acc11);
         acc11 = fma(a0.s1, b1.s1, acc11);
         acc11 = fma(a0.s2, b2.s1, acc11);
         acc11 = fma(a0.s3, b3.s1, acc11);
+        acc11 = fma(a0.s4, b4.s1, acc11);
+        acc11 = fma(a0.s5, b5.s1, acc11);
+        acc11 = fma(a0.s6, b6.s1, acc11);
+        acc11 = fma(a0.s7, b7.s1, acc11);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-        a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
         acc20 = fma(a0.s0, b0.s0, acc20);
         acc20 = fma(a0.s1, b1.s0, acc20);
         acc20 = fma(a0.s2, b2.s0, acc20);
         acc20 = fma(a0.s3, b3.s0, acc20);
+        acc20 = fma(a0.s4, b4.s0, acc20);
+        acc20 = fma(a0.s5, b5.s0, acc20);
+        acc20 = fma(a0.s6, b6.s0, acc20);
+        acc20 = fma(a0.s7, b7.s0, acc20);
 
         acc21 = fma(a0.s0, b0.s1, acc21);
         acc21 = fma(a0.s1, b1.s1, acc21);
         acc21 = fma(a0.s2, b2.s1, acc21);
         acc21 = fma(a0.s3, b3.s1, acc21);
+        acc21 = fma(a0.s4, b4.s1, acc21);
+        acc21 = fma(a0.s5, b5.s1, acc21);
+        acc21 = fma(a0.s6, b6.s1, acc21);
+        acc21 = fma(a0.s7, b7.s1, acc21);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
         acc30 = fma(a0.s0, b0.s0, acc30);
         acc30 = fma(a0.s1, b1.s0, acc30);
         acc30 = fma(a0.s2, b2.s0, acc30);
         acc30 = fma(a0.s3, b3.s0, acc30);
+        acc30 = fma(a0.s4, b4.s0, acc30);
+        acc30 = fma(a0.s5, b5.s0, acc30);
+        acc30 = fma(a0.s6, b6.s0, acc30);
+        acc30 = fma(a0.s7, b7.s0, acc30);
 
         acc31 = fma(a0.s0, b0.s1, acc31);
         acc31 = fma(a0.s1, b1.s1, acc31);
         acc31 = fma(a0.s2, b2.s1, acc31);
         acc31 = fma(a0.s3, b3.s1, acc31);
+        acc31 = fma(a0.s4, b4.s1, acc31);
+        acc31 = fma(a0.s5, b5.s1, acc31);
+        acc31 = fma(a0.s6, b6.s1, acc31);
+        acc31 = fma(a0.s7, b7.s1, acc31);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += sizeof(float) * 8;
     }
 
     // float size increment
-    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(4, src1_stride_y))
+    for(; i < (int)COLS_A; ++i)
     {
         // Load values from matrix A
         float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1471,6 +1620,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         // Load values from matrix B
         float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0, b0.s0, acc00);
@@ -1487,6 +1637,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
         acc30 = fma(a3, b0.s0, acc30);
         acc31 = fma(a3, b0.s1, acc31);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += sizeof(float);
     }
 
     // Compute destination address
--
cgit v1.2.1
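
Note on the pattern: the hunks above replace the old end-address loop bounds (src_end_addr_b, end_row_vec_a) with plain integer trip counts, unroll the accumulation four times per iteration, and bump the source pointers immediately after each load so that the loop body reduces to back-to-back fma calls. The sketch below distils that shape into a standalone OpenCL C helper for a single 1x4 accumulator tile; it is illustrative only (the function name, arguments, tile size and row-major B layout are assumptions of this sketch, not code from the patch).

// Illustrative sketch only (not part of the patch): one 1x4 accumulator tile
// computed with the refactored loop shape. a_row points to one row of A
// (cols_a floats); b points to a 4-column block of B stored with a row
// stride of ldb floats.
float4 gemm_1x4_tile_sketch(__global const float *a_row,
                            __global const float *b,
                            int                   cols_a,
                            int                   ldb)
{
    float acc0 = 0.0f;
    float acc1 = 0.0f;
    float acc2 = 0.0f;
    float acc3 = 0.0f;

    // Counter-based bound instead of an end-address comparison.
    int i = 0;
    for(; i <= (cols_a - 4); i += 4)
    {
        // One vload4 of A feeds four unrolled steps; pointers are advanced
        // right after the loads, and each step is pure FMA work.
        float4 a0 = vload4(0, a_row);
        a_row += 4;

        float4 b0 = vload4(0, b);
        b += ldb;
        acc0 = fma(a0.s0, b0.s0, acc0);
        acc1 = fma(a0.s0, b0.s1, acc1);
        acc2 = fma(a0.s0, b0.s2, acc2);
        acc3 = fma(a0.s0, b0.s3, acc3);

        b0 = vload4(0, b);
        b += ldb;
        acc0 = fma(a0.s1, b0.s0, acc0);
        acc1 = fma(a0.s1, b0.s1, acc1);
        acc2 = fma(a0.s1, b0.s2, acc2);
        acc3 = fma(a0.s1, b0.s3, acc3);

        b0 = vload4(0, b);
        b += ldb;
        acc0 = fma(a0.s2, b0.s0, acc0);
        acc1 = fma(a0.s2, b0.s1, acc1);
        acc2 = fma(a0.s2, b0.s2, acc2);
        acc3 = fma(a0.s2, b0.s3, acc3);

        b0 = vload4(0, b);
        b += ldb;
        acc0 = fma(a0.s3, b0.s0, acc0);
        acc1 = fma(a0.s3, b0.s1, acc1);
        acc2 = fma(a0.s3, b0.s2, acc2);
        acc3 = fma(a0.s3, b0.s3, acc3);
    }

    // Leftover columns when cols_a is not a multiple of 4.
    for(; i < cols_a; ++i)
    {
        float  a0 = *a_row;
        a_row += 1;

        float4 b0 = vload4(0, b);
        b += ldb;

        acc0 = fma(a0, b0.s0, acc0);
        acc1 = fma(a0, b0.s1, acc1);
        acc2 = fma(a0, b0.s2, acc2);
        acc3 = fma(a0, b0.s3, acc3);
    }

    return (float4)(acc0, acc1, acc2, acc3);
}

In the patch itself the same shape is applied to full 4x4 accumulator tiles: gemm_mm_interleaved_transposed_f32_bifrost advances its pointers by multiples of MULT_INTERLEAVE4X4_HEIGHT and MULT_TRANSPOSE1XW_WIDTH, while gemm_mm_floating_point_f32_bifrost walks matrix B one src1_stride_y per unrolled step.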
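
The gemm_mm_floating_point_f32_bifrost_1000 hunks follow the same idea but trade tile width for unroll depth: eight columns of A per iteration (vload8) against a two-column strip of B (one vload2 per B row). A correspondingly reduced sketch, again with hypothetical names, a 1x2 tile and the same assumed row-major B layout as above:

// Illustrative sketch only (not part of the patch): 1x2 tile, unrolled by 8.
float2 gemm_1x2_tile_unroll8_sketch(__global const float *a_row,
                                    __global const float *b,
                                    int                   cols_a,
                                    int                   ldb)
{
    float acc0 = 0.0f;
    float acc1 = 0.0f;

    int i = 0;
    for(; i <= (cols_a - 8); i += 8)
    {
        // Eight values of A in one load, eight rows of B loaded 2-wide,
        // each B pointer bump issued right after its load.
        float8 a0 = vload8(0, a_row);
        a_row += 8;

        float2 b0 = vload2(0, b); b += ldb;
        float2 b1 = vload2(0, b); b += ldb;
        float2 b2 = vload2(0, b); b += ldb;
        float2 b3 = vload2(0, b); b += ldb;
        float2 b4 = vload2(0, b); b += ldb;
        float2 b5 = vload2(0, b); b += ldb;
        float2 b6 = vload2(0, b); b += ldb;
        float2 b7 = vload2(0, b); b += ldb;

        acc0 = fma(a0.s0, b0.s0, acc0);
        acc0 = fma(a0.s1, b1.s0, acc0);
        acc0 = fma(a0.s2, b2.s0, acc0);
        acc0 = fma(a0.s3, b3.s0, acc0);
        acc0 = fma(a0.s4, b4.s0, acc0);
        acc0 = fma(a0.s5, b5.s0, acc0);
        acc0 = fma(a0.s6, b6.s0, acc0);
        acc0 = fma(a0.s7, b7.s0, acc0);

        acc1 = fma(a0.s0, b0.s1, acc1);
        acc1 = fma(a0.s1, b1.s1, acc1);
        acc1 = fma(a0.s2, b2.s1, acc1);
        acc1 = fma(a0.s3, b3.s1, acc1);
        acc1 = fma(a0.s4, b4.s1, acc1);
        acc1 = fma(a0.s5, b5.s1, acc1);
        acc1 = fma(a0.s6, b6.s1, acc1);
        acc1 = fma(a0.s7, b7.s1, acc1);
    }

    // Leftover columns.
    for(; i < cols_a; ++i)
    {
        float  a0 = *a_row;
        a_row += 1;

        float2 b0 = vload2(0, b);
        b += ldb;

        acc0 = fma(a0, b0.s0, acc0);
        acc1 = fma(a0, b0.s1, acc1);
    }

    return (float2)(acc0, acc1);
}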