From c9c62c2fa1c80ba7f11b0d0732740460dfa00e74 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 6 Apr 2018 10:00:10 +0100 Subject: COMPMID-1056 - Optimizing CLGEMMMatrixMultiplyKernel refactoring the inner loop Results reported at: https://confluence.arm.com/display/MLENG/GEMM+FP32+performance%3A+ACL+18.05 Change-Id: I3246c4f19c4d21a7d6a44e4593bc5caffc016f81 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127838 Tested-by: Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Anthony Barbier --- src/core/CL/cl_kernels/gemm.cl | 260 ++++++++++++++++----- src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 23 +- 2 files changed, 216 insertions(+), 67 deletions(-) (limited to 'src') diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index 4b1672ce7b..381130ea7f 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -342,9 +342,6 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes); __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes); - // Compute end row address for matrix B - __global float *src_end_addr_b = src_addr_b + COLS_B; - src_addr_a += offset_row_a; src_addr_b += offset_row_b; @@ -366,12 +363,18 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) float c32 = 0.0f; float c33 = 0.0f; - for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH)) +#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH)) + + int i = 0; + for(; i <= (int)(COLS_MTX_B - 4); i += 4) { // Load values from matrix A (interleaved) and matrix B (transposed) float4 a0 = vload4(0, src_addr_a); float4 b0 = vload4(0, src_addr_b); + src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; + src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; + c00 = fma(a0.s0, b0.s0, c00); c01 = fma(a0.s0, b0.s1, c01); c02 = fma(a0.s0, b0.s2, c02); @@ -393,8 +396,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) c33 = fma(a0.s3, b0.s3, c33); // Load values from matrix A (interleaved) and matrix B (transposed) - a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT); - b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH); + a0 = vload4(0, src_addr_a); + b0 = vload4(0, src_addr_b); + + src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; + src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; c00 = fma(a0.s0, b0.s0, c00); c01 = fma(a0.s0, b0.s1, c01); @@ -417,8 +423,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) c33 = fma(a0.s3, b0.s3, c33); // Load values from matrix A (interleaved) and matrix B (transposed) - a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT); - b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH); + a0 = vload4(0, src_addr_a); + b0 = vload4(0, src_addr_b); + + src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; + src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; c00 = fma(a0.s0, b0.s0, c00); c01 = fma(a0.s0, b0.s1, c01); @@ -441,8 +450,11 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) c33 = fma(a0.s3, b0.s3, c33); // Load values from matrix A (interleaved) and matrix B (transposed) - a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT); - b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH); + a0 = vload4(0, src_addr_a); + b0 = vload4(0, src_addr_b); + + src_addr_a += 4 * 
MULT_INTERLEAVE4X4_HEIGHT; + src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; c00 = fma(a0.s0, b0.s0, c00); c01 = fma(a0.s0, b0.s1, c01); @@ -465,12 +477,15 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) c33 = fma(a0.s3, b0.s3, c33); } - for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH)) + for(; i < (int)(COLS_MTX_B); ++i) { // Load values from matrix A (interleaved) and matrix B (transposed) float4 a0 = vload4(0, src_addr_a); float4 b0 = vload4(0, src_addr_b); + src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; + src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; + c00 = fma(a0.s0, b0.s0, c00); c01 = fma(a0.s0, b0.s1, c01); c02 = fma(a0.s0, b0.s2, c02); @@ -1130,9 +1145,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), src_addr.s1 += get_global_id(2) * src1_stride_z; #endif // defined(MATRIX_B_DEPTH) - // Address boundary for matrix A - int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float)); - // Initialize accumulators float acc00 = 0.0f; float acc01 = 0.0f; @@ -1161,72 +1173,162 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // A and B src indices get incremented at the same time. - for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y)) + int i = 0; + for(; i <= ((int)COLS_A - 4); i += 4) { - // Load values from matrix A - float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); + // Load values from matrix A and matrix B + float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); + float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); + float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); + float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - // Load values from matrix B - float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y)); - float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y)); + float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; // Multiply and accumulate acc00 = fma(a0.s0, b0.s0, acc00); - acc00 = fma(a0.s1, b1.s0, acc00); acc01 = fma(a0.s0, b0.s1, acc01); - acc01 = fma(a0.s1, b1.s1, acc01); acc02 = fma(a0.s0, b0.s2, acc02); - acc02 = fma(a0.s1, b1.s2, acc02); - acc03 = fma(a0.s1, b1.s3, acc03); acc03 = fma(a0.s0, b0.s3, acc03); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + acc10 = fma(a1.s0, b0.s0, acc10); acc11 = fma(a1.s0, b0.s1, acc11); acc12 = fma(a1.s0, b0.s2, acc12); acc13 = fma(a1.s0, b0.s3, acc13); - acc10 = fma(a1.s1, b1.s0, acc10); - acc11 = fma(a1.s1, b1.s1, acc11); - acc12 = fma(a1.s1, b1.s2, acc12); - acc13 = fma(a1.s1, b1.s3, acc13); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if 
NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + acc20 = fma(a2.s0, b0.s0, acc20); acc21 = fma(a2.s0, b0.s1, acc21); acc22 = fma(a2.s0, b0.s2, acc22); acc23 = fma(a2.s0, b0.s3, acc23); - acc20 = fma(a2.s1, b1.s0, acc20); - acc21 = fma(a2.s1, b1.s1, acc21); - acc22 = fma(a2.s1, b1.s2, acc22); - acc23 = fma(a2.s1, b1.s3, acc23); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + acc30 = fma(a3.s0, b0.s0, acc30); acc31 = fma(a3.s0, b0.s1, acc31); acc32 = fma(a3.s0, b0.s2, acc32); acc33 = fma(a3.s0, b0.s3, acc33); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + // Load values from matrix A and matrix B + b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + + // Multiply and accumulate + acc00 = fma(a0.s1, b0.s0, acc00); + acc01 = fma(a0.s1, b0.s1, acc01); + acc02 = fma(a0.s1, b0.s2, acc02); + acc03 = fma(a0.s1, b0.s3, acc03); + +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + + acc10 = fma(a1.s1, b0.s0, acc10); + acc11 = fma(a1.s1, b0.s1, acc11); + acc12 = fma(a1.s1, b0.s2, acc12); + acc13 = fma(a1.s1, b0.s3, acc13); + +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + + acc20 = fma(a2.s1, b0.s0, acc20); + acc21 = fma(a2.s1, b0.s1, acc21); + acc22 = fma(a2.s1, b0.s2, acc22); + acc23 = fma(a2.s1, b0.s3, acc23); + +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 = fma(a3.s1, b1.s0, acc30); - acc31 = fma(a3.s1, b1.s1, acc31); - acc32 = fma(a3.s1, b1.s2, acc32); - acc33 = fma(a3.s1, b1.s3, acc33); + acc30 = fma(a3.s1, b0.s0, acc30); + acc31 = fma(a3.s1, b0.s1, acc31); + acc32 = fma(a3.s1, b0.s2, acc32); + acc33 = fma(a3.s1, b0.s3, acc33); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + // Load values from matrix A and matrix B + b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + + // Multiply and accumulate + acc00 = fma(a0.s2, b0.s0, acc00); + acc01 = fma(a0.s2, b0.s1, acc01); + acc02 = fma(a0.s2, b0.s2, acc02); + acc03 = fma(a0.s2, b0.s3, acc03); + +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + + acc10 = fma(a1.s2, b0.s0, acc10); + acc11 = fma(a1.s2, b0.s1, acc11); + acc12 = fma(a1.s2, b0.s2, acc12); + acc13 = fma(a1.s2, b0.s3, acc13); + +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + + acc20 = fma(a2.s2, b0.s0, acc20); + acc21 = fma(a2.s2, b0.s1, acc21); + acc22 = fma(a2.s2, b0.s2, acc22); + acc23 = fma(a2.s2, b0.s3, acc23); + +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + acc30 = fma(a3.s2, b0.s0, acc30); + acc31 = fma(a3.s2, b0.s1, acc31); + acc32 = fma(a3.s2, b0.s2, acc32); + acc33 = fma(a3.s2, b0.s3, acc33); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + // Load values from matrix A and matrix B + b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + + // Multiply and accumulate + acc00 = fma(a0.s3, b0.s0, acc00); + acc01 = fma(a0.s3, b0.s1, acc01); + acc02 = fma(a0.s3, b0.s2, acc02); + acc03 = fma(a0.s3, b0.s3, acc03); + +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + + acc10 = fma(a1.s3, b0.s0, acc10); + acc11 = fma(a1.s3, b0.s1, acc11); + acc12 = fma(a1.s3, b0.s2, acc12); + acc13 = fma(a1.s3, b0.s3, acc13); + +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + + acc20 = fma(a2.s3, b0.s0, acc20); + acc21 = fma(a2.s3, b0.s1, acc21); + acc22 = fma(a2.s3, b0.s2, acc22); + acc23 = fma(a2.s3, b0.s3, acc23); + +#endif // 
NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + acc30 = fma(a3.s3, b0.s0, acc30); + acc31 = fma(a3.s3, b0.s1, acc31); + acc32 = fma(a3.s3, b0.s2, acc32); + acc33 = fma(a3.s3, b0.s3, acc33); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += 4 * sizeof(float); } - for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y)) + for(; i < (int)COLS_A; ++i) { // Load values from matrix A - float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); + float a0 = *((__global float *)(src0_ptr + src_addr.s0)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -1238,6 +1340,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // Load values from matrix B float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; // Multiply and accumulate acc00 = fma(a0, b0.s0, acc00); @@ -1262,6 +1365,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), acc32 = fma(a3, b0.s2, acc32); acc33 = fma(a3, b0.s3, acc33); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += sizeof(float); } // Compute destination address @@ -1375,9 +1480,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), src_addr.s1 += get_global_id(2) * src1_stride_z; #endif // defined(MATRIX_B_DEPTH) - // Address boundary for the matrix A - int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float)); - // Initialize accumulators float acc00 = 0.0f; float acc01 = 0.0f; @@ -1396,67 +1498,114 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // A and B src indices get incremented at the same time. 
- for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y)) + int i = 0; + for(; i <= ((int)COLS_A - 8); i += 8) { // Load values from matrix A - float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); + float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0)); // Load values from matrix B - float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y)); - float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y)); - float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y)); - float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y)); + float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; // Multiply and accumulate acc00 = fma(a0.s0, b0.s0, acc00); acc00 = fma(a0.s1, b1.s0, acc00); acc00 = fma(a0.s2, b2.s0, acc00); acc00 = fma(a0.s3, b3.s0, acc00); + acc00 = fma(a0.s4, b4.s0, acc00); + acc00 = fma(a0.s5, b5.s0, acc00); + acc00 = fma(a0.s6, b6.s0, acc00); + acc00 = fma(a0.s7, b7.s0, acc00); acc01 = fma(a0.s0, b0.s1, acc01); acc01 = fma(a0.s1, b1.s1, acc01); acc01 = fma(a0.s2, b2.s1, acc01); acc01 = fma(a0.s3, b3.s1, acc01); + acc01 = fma(a0.s4, b4.s1, acc01); + acc01 = fma(a0.s5, b5.s1, acc01); + acc01 = fma(a0.s6, b6.s1, acc01); + acc01 = fma(a0.s7, b7.s1, acc01); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); acc10 = fma(a0.s0, b0.s0, acc10); acc10 = fma(a0.s1, b1.s0, acc10); acc10 = fma(a0.s2, b2.s0, acc10); acc10 = fma(a0.s3, b3.s0, acc10); + acc10 = fma(a0.s4, b4.s0, acc10); + acc10 = fma(a0.s5, b5.s0, acc10); + acc10 = fma(a0.s6, b6.s0, acc10); + acc10 = fma(a0.s7, b7.s0, acc10); acc11 = fma(a0.s0, b0.s1, acc11); acc11 = fma(a0.s1, b1.s1, acc11); acc11 = fma(a0.s2, b2.s1, acc11); acc11 = fma(a0.s3, b3.s1, acc11); + acc11 = fma(a0.s4, b4.s1, acc11); + acc11 = fma(a0.s5, b5.s1, acc11); + acc11 = fma(a0.s6, b6.s1, acc11); + acc11 = fma(a0.s7, b7.s1, acc11); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); acc20 = fma(a0.s0, b0.s0, acc20); acc20 = fma(a0.s1, b1.s0, acc20); acc20 = fma(a0.s2, b2.s0, acc20); acc20 = fma(a0.s3, b3.s0, acc20); + acc20 = fma(a0.s4, b4.s0, acc20); + acc20 = fma(a0.s5, b5.s0, acc20); + acc20 = fma(a0.s6, b6.s0, acc20); + acc20 = fma(a0.s7, b7.s0, acc20); acc21 = fma(a0.s0, b0.s1, acc21); acc21 = fma(a0.s1, b1.s1, acc21); acc21 = fma(a0.s2, b2.s1, acc21); acc21 = 
fma(a0.s3, b3.s1, acc21); + acc21 = fma(a0.s4, b4.s1, acc21); + acc21 = fma(a0.s5, b5.s1, acc21); + acc21 = fma(a0.s6, b6.s1, acc21); + acc21 = fma(a0.s7, b7.s1, acc21); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); acc30 = fma(a0.s0, b0.s0, acc30); acc30 = fma(a0.s1, b1.s0, acc30); acc30 = fma(a0.s2, b2.s0, acc30); acc30 = fma(a0.s3, b3.s0, acc30); + acc30 = fma(a0.s4, b4.s0, acc30); + acc30 = fma(a0.s5, b5.s0, acc30); + acc30 = fma(a0.s6, b6.s0, acc30); + acc30 = fma(a0.s7, b7.s0, acc30); acc31 = fma(a0.s0, b0.s1, acc31); acc31 = fma(a0.s1, b1.s1, acc31); acc31 = fma(a0.s2, b2.s1, acc31); acc31 = fma(a0.s3, b3.s1, acc31); + acc31 = fma(a0.s4, b4.s1, acc31); + acc31 = fma(a0.s5, b5.s1, acc31); + acc31 = fma(a0.s6, b6.s1, acc31); + acc31 = fma(a0.s7, b7.s1, acc31); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += sizeof(float) * 8; } // float size increment - for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(4, src1_stride_y)) + for(; i < (int)COLS_A; ++i) { // Load values from matrix A float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); @@ -1471,6 +1620,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // Load values from matrix B float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; // Multiply and accumulate acc00 = fma(a0, b0.s0, acc00); @@ -1487,6 +1637,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), acc30 = fma(a3, b0.s0, acc30); acc31 = fma(a3, b0.s1, acc31); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += sizeof(float); } // Compute destination address diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp index 5dd1f00fa6..9b3bf48bca 100644 --- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp @@ -37,10 +37,8 @@ using namespace arm_compute::misc::shape_calculator; namespace { -Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output, bool is_interleaved_transposed) +Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output) { - const GPUTarget gpu_target = CLScheduler::get().target(); - if(is_data_type_quantized_asymmetric(input.data_type())) { // Since we need negative offsets for computing convolution, we need to change QuantizationInfo() @@ -55,7 +53,7 @@ Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const I } else { - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&input, &weights, &output, 1.f, is_interleaved_transposed, GEMMReshapeInfo(), gpu_target)); + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input, &weights, nullptr, &output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */))); } return Status{}; @@ -75,12 +73,12 @@ Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c } CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr memory_manager) - : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), 
_im2col_output(), - _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr) + : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), + _im2col_output(), _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr) { } -void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed) +void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output) { if(_is_quantized) { @@ -102,8 +100,7 @@ void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor else { // Configure matrix multiply kernel - _mm_kernel.set_target(CLScheduler::get().target()); - _mm_kernel.configure(input, weights, output, 1.f, is_interleaved_transposed); + _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)); } } @@ -122,7 +119,7 @@ void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLT _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false); // Configure matrix multiply kernel - configure_mm(&_im2col_output, weights, output, false); + configure_mm(&_im2col_output, weights, output); // Allocate the output tensor for im2col once all the configure methods have been called _im2col_output.allocator()->allocate(); @@ -133,7 +130,7 @@ void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTen ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1)); // Configure matrix multiply kernel - configure_mm(input, weights, output, false); + configure_mm(input, weights, output); } void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped) @@ -301,7 +298,7 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1)); } // Validate matrix multiply kernel - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output, false)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output)); // Validate output stage for asymmetric quantized types if(is_quantized) @@ -341,7 +338,7 @@ void CLFullyConnectedLayer::run() } else { - CLScheduler::get().enqueue(_mm_kernel, !_accumulate_biases); + _mm_gemm.run(); } // Accumulate biases if provided -- cgit v1.2.1