Diffstat (limited to 'src/core/CL/cl_kernels/gemmlowp.cl')
-rw-r--r-- | src/core/CL/cl_kernels/gemmlowp.cl | 1527
1 file changed, 1211 insertions(+), 316 deletions(-)
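
For context, the central change in this patch is the reworked ARM_DOT macro, which now takes two uchar4 operands plus an accumulator instead of eight scalar bytes. Below is a minimal sketch of how the new form accumulates one row of a 4x4 output tile; the kernel name example_dot_tile and the non-extension fallback branch are illustrative assumptions for this sketch, not part of the patch, and only the dot8 branch mirrors the macro added here.

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#define ARM_DOT(x, y, val) val += arm_dot((x), (y));
#else // Illustrative fallback only: plain uchar4 dot product without the extension
#define ARM_DOT(x, y, val)                                 \
    val += (uint)(x).s0 * (y).s0 + (uint)(x).s1 * (y).s1 + \
           (uint)(x).s2 * (y).s2 + (uint)(x).s3 * (y).s3;
#endif

// Accumulate one row of a 4x4 tile: a 1x4 strip of matrix A against a
// 4x4 block of matrix B, column by column, as the kernels in this patch do.
__kernel void example_dot_tile(__global const uchar *a_row,   // 4 values of one row of A
                               __global const uchar *b_block, // 4x4 block of B, row-major
                               __global uint *c_row)          // 4 output accumulators
{
    uchar4 a0 = vload4(0, a_row);
    uchar4 b0 = vload4(0, b_block + 0);
    uchar4 b1 = vload4(0, b_block + 4);
    uchar4 b2 = vload4(0, b_block + 8);
    uchar4 b3 = vload4(0, b_block + 12);

    uint c00 = 0, c01 = 0, c02 = 0, c03 = 0;

    // Each ARM_DOT consumes four A values and one column of the B block
    ARM_DOT(a0, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
    ARM_DOT(a0, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
    ARM_DOT(a0, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
    ARM_DOT(a0, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);

    vstore4((uint4)(c00, c01, c02, c03), 0, c_row);
}
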
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl index 80b5d00cf2..35e0d9dba5 100644 --- a/src/core/CL/cl_kernels/gemmlowp.cl +++ b/src/core/CL/cl_kernels/gemmlowp.cl @@ -26,9 +26,9 @@ #if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) #if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) -#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val = arm_dot_acc((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3), val); +#define ARM_DOT(x, y, val) val = arm_dot_acc((x), (y), (val)); #else // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) -#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val += arm_dot((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3)); +#define ARM_DOT(x, y, val) val += arm_dot((x), (y)); #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) @@ -600,29 +600,22 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION( #endif // REINTERPRET_OUTPUT_AS_3D ) { - const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP; - const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT; - const int z = get_global_id(2); - // Offset const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4; const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4; // src_addr_a = address of matrix A // src_addr_b = address of matrix B - __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes); - __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes); + __global uchar *src_addr_a = (__global uchar *)(src0_ptr + (get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT) * src0_stride_y + get_global_id(2) * src0_stride_z + src0_offset_first_element_in_bytes); + __global uchar *src_addr_b = (__global uchar *)(src1_ptr + (get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP) * src1_stride_y + src1_offset_first_element_in_bytes); #if defined(MATRIX_B_DEPTH) // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 - src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z; + src_addr_b += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; #else // defined(MATRIX_B_DEPTH) - src_addr_b += z * src1_stride_z; + src_addr_b += get_global_id(2) * src1_stride_z; #endif // defined(MATRIX_B_DEPTH) - // Compute end row address for matrix B - __global uchar *src_end_addr_b = src_addr_b + COLS_B; - src_addr_a += offset_row_a; src_addr_b += offset_row_b; @@ -631,21 +624,27 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION( uint c01 = 0; uint c02 = 0; uint c03 = 0; + uint c10 = 0; uint c11 = 0; uint c12 = 0; uint c13 = 0; + uint c20 = 0; uint c21 = 0; uint c22 = 0; uint c23 = 0; + uint c30 = 0; uint c31 = 0; uint c32 = 0; uint c33 = 0; +#define COLS_MTX_B (COLS_B / (16 * MULT_TRANSPOSE1XW_WIDTH)) + #if MULT_INTERLEAVE4X4_HEIGHT == 1 - for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP)) + int i = 0; + for(; i <= (int)(COLS_MTX_B - 8); i += 8) { // Load values from matrix A (interleaved) and matrix B (transposed) uchar16 a0 = vload16(0, src_addr_a); @@ -653,83 
+652,88 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION( uchar4 b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP); uchar4 b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP); uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP); + uchar4 b4 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP); + uchar4 b5 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP); + uchar4 b6 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP); + uchar4 b7 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP); // Accumulate - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00); - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01); - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02); - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, b1.s3, b2.s3, b3.s3, c03); - - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10); - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11); - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12); - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13); - - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20); - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21); - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22); - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23); - - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30); - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31); - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32); - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33); - - // Load values from matrix A (interleaved) and matrix B (transposed) - a0 = vload16(0, src_addr_a + 16); - b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP); - b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP); - b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP); - b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03); + + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10); + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11); + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12); + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13); + + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20); + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21); + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22); + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23); + + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30); + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31); + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32); + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33); // Accumulate - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00); - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01); - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02); - ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, 
b1.s3, b2.s3, b3.s3, c03); - - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10); - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11); - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12); - ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13); - - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20); - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21); - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22); - ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23); - - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30); - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31); - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32); - ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33); - } -#endif // MULT_INTERLEAVE4X4_HEIGHT == 1 + a0 = vload16(0, src_addr_a + 16); - for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP)) - { - // Load values from matrix A (interleaved) and matrix B (transposed) - uchar4 a0 = vload4(0, src_addr_a); - uchar4 b0 = vload4(0, src_addr_b); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c00); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c01); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c02); + ARM_DOT((uchar4)(a0.s0123), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c03); - c00 += (ushort)a0.s0 * b0.s0; - c01 += (ushort)a0.s0 * b0.s1; - c02 += (ushort)a0.s0 * b0.s2; - c03 += (ushort)a0.s0 * b0.s3; + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c10); + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c11); + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c12); + ARM_DOT((uchar4)(a0.s4567), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c13); - c10 += (ushort)a0.s1 * b0.s0; - c11 += (ushort)a0.s1 * b0.s1; - c12 += (ushort)a0.s1 * b0.s2; - c13 += (ushort)a0.s1 * b0.s3; + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c20); + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c21); + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c22); + ARM_DOT((uchar4)(a0.s89AB), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c23); - c20 += (ushort)a0.s2 * b0.s0; - c21 += (ushort)a0.s2 * b0.s1; - c22 += (ushort)a0.s2 * b0.s2; - c23 += (ushort)a0.s2 * b0.s3; + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s0, b5.s0, b6.s0, b7.s0), c30); + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s1, b5.s1, b6.s1, b7.s1), c31); + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s2, b5.s2, b6.s2, b7.s2), c32); + ARM_DOT((uchar4)(a0.sCDEF), (uchar4)(b4.s3, b5.s3, b6.s3, b7.s3), c33); - c30 += (ushort)a0.s3 * b0.s0; - c31 += (ushort)a0.s3 * b0.s1; - c32 += (ushort)a0.s3 * b0.s2; - c33 += (ushort)a0.s3 * b0.s3; + src_addr_a += 32; + src_addr_b += 32 * TRANSPOSE1XW_WIDTH_STEP; + } +#endif // MULT_INTERLEAVE4X4_HEIGHT == 1 + int i_left_over = 0; + for(; i < (int)(COLS_MTX_B); ++i) + { + // Load values from matrix A (interleaved) and matrix B (transposed) + uchar16 a0 = vload16(0, src_addr_a + (i_left_over % 4) + ((i_left_over / 4) * 16)); + uchar4 b0 = vload4(0, src_addr_b); + + c00 += a0.s0 * b0.s0; + c01 += a0.s0 * b0.s1; + c02 += a0.s0 * b0.s2; + c03 += a0.s0 * b0.s3; + + c10 += a0.s4 * b0.s0; + c11 += a0.s4 * b0.s1; + c12 += a0.s4 * b0.s2; + c13 += a0.s4 * b0.s3; + + c20 += a0.s8 * b0.s0; + c21 += 
a0.s8 * b0.s1; + c22 += a0.s8 * b0.s2; + c23 += a0.s8 * b0.s3; + + c30 += a0.sC * b0.s0; + c31 += a0.sC * b0.s1; + c32 += a0.sC * b0.s2; + c33 += a0.sC * b0.s3; + + i_left_over++; + src_addr_b += 4 * TRANSPOSE1XW_WIDTH_STEP; } // Compute destination address @@ -760,7 +764,7 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION( // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D - dst.ptr += z * dst_stride_z * DEPTH_GEMM3D; + dst.ptr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D; // Store 4x4 block vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0)); @@ -770,7 +774,7 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION( #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM - dst.ptr += z * dst_stride_z; + dst.ptr += get_global_id(2) * dst_stride_z; // Store 4x4 block vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y)); @@ -1605,6 +1609,8 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0), // Add offset due to the cross plane paddings zin *= (src_cross_plane_pad * src0_stride_y); + zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y; + // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply src0_stride_z by DEPTH_GEMM3D src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; @@ -1623,199 +1629,635 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0), src_addr.s1 += get_global_id(2) * src1_stride_z; #endif // defined(MATRIX_B_DEPTH) - int end_row_vec_a = src_addr.s0 + COLS_A; - uint acc00 = 0; uint acc01 = 0; uint acc02 = 0; uint acc03 = 0; + uint acc04 = 0; + uint acc05 = 0; + uint acc06 = 0; + uint acc07 = 0; #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 uint acc10 = 0; uint acc11 = 0; uint acc12 = 0; uint acc13 = 0; + uint acc14 = 0; + uint acc15 = 0; + uint acc16 = 0; + uint acc17 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 uint acc20 = 0; uint acc21 = 0; uint acc22 = 0; uint acc23 = 0; + uint acc24 = 0; + uint acc25 = 0; + uint acc26 = 0; + uint acc27 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 uint acc30 = 0; uint acc31 = 0; uint acc32 = 0; uint acc33 = 0; + uint acc34 = 0; + uint acc35 = 0; + uint acc36 = 0; + uint acc37 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uint acc40 = 0; - uint acc41 = 0; - uint acc42 = 0; - uint acc43 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y)) + // A and B src indices get incremented at the same time. 
+ int i = 0; + for(; i <= ((int)COLS_A - 8); i += 8) { - // Load values from matrix A - uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y); +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A and matrix B + uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y); + uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y); + uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y); + uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - // Load values from matrix B - uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); - uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); - uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); - uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); +#else // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A and matrix B + uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + + uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); + uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); + uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); + uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); + src_addr.s1 += 4 * src1_stride_y; + + ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00); + ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01); + ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02); + ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03); + ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04); + ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05); + ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06); + ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07); - { - // Accumulate - ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a0.s0, a0.s1, a0.s2, a0.s3, acc00); - ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a0.s0, a0.s1, a0.s2, a0.s3, acc01); - ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a0.s0, a0.s1, a0.s2, a0.s3, acc02); - ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a0.s0, a0.s1, a0.s2, a0.s3, acc03); - } #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - { - // Accumulate - ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a1.s0, a1.s1, a1.s2, a1.s3, acc10); - 
ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a1.s0, a1.s1, a1.s2, a1.s3, acc11); - ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a1.s0, a1.s1, a1.s2, a1.s3, acc12); - ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a1.s0, a1.s1, a1.s2, a1.s3, acc13); - } + ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10); + ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11); + ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12); + ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13); + ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14); + ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15); + ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16); + ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - { - // Accumulate - ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a2.s0, a2.s1, a2.s2, a2.s3, acc20); - ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a2.s0, a2.s1, a2.s2, a2.s3, acc21); - ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a2.s0, a2.s1, a2.s2, a2.s3, acc22); - ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a2.s0, a2.s1, a2.s2, a2.s3, acc23); - } + ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20); + ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21); + ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22); + ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23); + ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24); + ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25); + ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26); + ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - { - // Accumulate - ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a3.s0, a3.s1, a3.s2, a3.s3, acc30); - ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a3.s0, a3.s1, a3.s2, a3.s3, acc31); - ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a3.s0, a3.s1, a3.s2, a3.s3, acc32); - ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a3.s0, a3.s1, a3.s2, a3.s3, acc33); - } + ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30); + ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31); + ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32); + ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33); + ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34); + ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35); + ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36); + ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - { - // Accumulate - ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a4.s0, a4.s1, a4.s2, a4.s3, acc40); - ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a4.s0, a4.s1, a4.s2, a4.s3, acc41); - ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a4.s0, a4.s1, a4.s2, a4.s3, acc42); - ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a4.s0, a4.s1, a4.s2, a4.s3, acc43); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 + + b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); + b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); + b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); + b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); + src_addr.s1 += 4 * src1_stride_y; + + ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00); + ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01); + ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, 
b3.s2), acc02); + ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03); + ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04); + ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05); + ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06); + ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07); + +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10); + ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11); + ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12); + ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13); + ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14); + ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15); + ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16); + ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20); + ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21); + ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22); + ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23); + ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24); + ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25); + ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26); + ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30); + ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31); + ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32); + ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33); + ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34); + ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35); + ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36); + ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += 8; } - for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y)) + for(; i < (int)COLS_A; ++i) { +#if defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A - uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y); + uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y); + uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y); + uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y); + uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 +#else // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + uchar a1 = *((__global 
uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix B - uchar4 b0 = vload4(0, src1_ptr + src_addr.s1); + uchar8 b0 = vload8(0, src1_ptr + src_addr.s1); + src_addr.s1 += src1_stride_y; + + acc00 += (uint)a0 * b0.s0; + acc01 += (uint)a0 * b0.s1; + acc02 += (uint)a0 * b0.s2; + acc03 += (uint)a0 * b0.s3; + acc04 += (uint)a0 * b0.s4; + acc05 += (uint)a0 * b0.s5; + acc06 += (uint)a0 * b0.s6; + acc07 += (uint)a0 * b0.s7; - // Accumulate - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a0; - ushort tmp1 = (ushort)b0.s1 * (ushort)a0; - ushort tmp2 = (ushort)b0.s2 * (ushort)a0; - ushort tmp3 = (ushort)b0.s3 * (ushort)a0; +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + acc10 += (uint)a1 * b0.s0; + acc11 += (uint)a1 * b0.s1; + acc12 += (uint)a1 * b0.s2; + acc13 += (uint)a1 * b0.s3; + acc14 += (uint)a1 * b0.s4; + acc15 += (uint)a1 * b0.s5; + acc16 += (uint)a1 * b0.s6; + acc17 += (uint)a1 * b0.s7; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + acc20 += (uint)a2 * b0.s0; + acc21 += (uint)a2 * b0.s1; + acc22 += (uint)a2 * b0.s2; + acc23 += (uint)a2 * b0.s3; + acc24 += (uint)a2 * b0.s4; + acc25 += (uint)a2 * b0.s5; + acc26 += (uint)a2 * b0.s6; + acc27 += (uint)a2 * b0.s7; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + acc30 += (uint)a3 * b0.s0; + acc31 += (uint)a3 * b0.s1; + acc32 += (uint)a3 * b0.s2; + acc33 += (uint)a3 * b0.s3; + acc34 += (uint)a3 * b0.s4; + acc35 += (uint)a3 * b0.s5; + acc36 += (uint)a3 * b0.s6; + acc37 += (uint)a3 * b0.s7; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc00 += ((uint)tmp0); - acc01 += ((uint)tmp1); - acc02 += ((uint)tmp2); - acc03 += ((uint)tmp3); - } + src_addr.s0 += 1; + } + + int z = get_global_id(2); + + // Compute destination address + Image dst = CONVERT_TO_IMAGE_STRUCT(dst); + + // Compute dst address + __global uchar *dst_addr = dst.ptr; + +#if defined(REINTERPRET_OUTPUT_AS_3D) + // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension + // in order to take into account the presence of possible cross plane paddings + // + // | | + // | plane0 | + // | | + // |__________________| + // |******************| + // | cross_plane_pad | + // |******************| + // | | + // | plane1 | + // | | + // |__________________| + + // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D + uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); + + // Add offset due to the cross plane paddings + zout *= (dst_cross_plane_pad * dst_stride_y); + + // Add offset for batched GEMM. 
The batches will be in the fourth dimension and for this reason we + // multiply dst_stride_z by DEPTH_GEMM3D + dst_addr += z * dst_stride_z * DEPTH_GEMM3D; + + // Store the result + vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0)); + vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a1; - ushort tmp1 = (ushort)b0.s1 * (ushort)a1; - ushort tmp2 = (ushort)b0.s2 * (ushort)a1; - ushort tmp3 = (ushort)b0.s3 * (ushort)a1; + vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1)); + vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2)); + vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3)); + vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc10 += ((uint)tmp0); - acc11 += ((uint)tmp1); - acc12 += ((uint)tmp2); - acc13 += ((uint)tmp3); - } +#else // defined(REINTERPRET_OUTPUT_AS_3D) + // Add offset for batched GEMM + dst_addr += z * dst_stride_z; + + // Store the result + vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y)); + vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y)); + vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a2; - ushort tmp1 = (ushort)b0.s1 * (ushort)a2; - ushort tmp2 = (ushort)b0.s2 * (ushort)a2; - ushort tmp3 = (ushort)b0.s3 * (ushort)a2; + vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y)); + vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y)); + vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_OUTPUT_AS_3D) +} - acc20 += ((uint)tmp0); - acc21 += ((uint)tmp1); - acc22 += ((uint)tmp2); - acc23 += ((uint)tmp3); - } +__kernel void gemmlowp_mm_bifrost_transposed_dot8(IMAGE_DECLARATION(src0), + IMAGE_DECLARATION(src1), + IMAGE_DECLARATION(dst), + uint src0_stride_z, + uint src1_stride_z, + uint dst_stride_z +#if defined(REINTERPRET_INPUT_AS_3D) + , + uint src_cross_plane_pad +#endif // REINTERPRET_INPUT_AS_3D +#if defined(REINTERPRET_OUTPUT_AS_3D) + , + uint dst_cross_plane_pad +#endif // REINTERPRET_OUTPUT_AS_3D) + ) +{ + int idx = 
get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X; + + // Compute starting address for matrix A and Matrix B + int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); + + // Update address for the matrix A + src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y; + + // Update address for the matrix B + src_addr.s1 += idx; + +#if defined(REINTERPRET_INPUT_AS_3D) + // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension + // in order to take into account the presence of possible cross plane paddings + // + // | | + // | plane0 | + // | | + // |__________________| + // |******************| + // | cross_plane_pad | + // |******************| + // | | + // | plane1 | + // | | + // |__________________| + + // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D + uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zin = min(DEPTH_GEMM3D - 1, zin); + + // Add offset due to the cross plane paddings + zin *= (src_cross_plane_pad * src0_stride_y); + + zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y; + + // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we + // multiply src0_stride_z by DEPTH_GEMM3D + src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; + +#else // defined(REINTERPRET_INPUT_AS_3D) + + // Add offset for batched GEMM + src_addr.s0 += get_global_id(2) * src0_stride_z; + +#endif // defined(REINTERPRET_INPUT_AS_3D) + +#if defined(MATRIX_B_DEPTH) + // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 + src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; +#else // defined(MATRIX_B_DEPTH) + src_addr.s1 += get_global_id(2) * src1_stride_z; +#endif // defined(MATRIX_B_DEPTH) + + uint acc00 = 0; + uint acc01 = 0; + uint acc02 = 0; + uint acc03 = 0; + uint acc04 = 0; + uint acc05 = 0; + uint acc06 = 0; + uint acc07 = 0; +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + uint acc10 = 0; + uint acc11 = 0; + uint acc12 = 0; + uint acc13 = 0; + uint acc14 = 0; + uint acc15 = 0; + uint acc16 = 0; + uint acc17 = 0; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + uint acc20 = 0; + uint acc21 = 0; + uint acc22 = 0; + uint acc23 = 0; + uint acc24 = 0; + uint acc25 = 0; + uint acc26 = 0; + uint acc27 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a3; - ushort tmp1 = (ushort)b0.s1 * (ushort)a3; - ushort tmp2 = (ushort)b0.s2 * (ushort)a3; - ushort tmp3 = (ushort)b0.s3 * (ushort)a3; + uint acc30 = 0; + uint acc31 = 0; + uint acc32 = 0; + uint acc33 = 0; + uint acc34 = 0; + uint acc35 = 0; + uint acc36 = 0; + uint acc37 = 0; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 += ((uint)tmp0); - acc31 += ((uint)tmp1); - acc32 += ((uint)tmp2); - acc33 += ((uint)tmp3); - } + // A and B src indices get incremented at the same time. 
+ int i = 0; + for(; i <= ((int)COLS_A - 8); i += 8) + { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A and matrix B + uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a4; - ushort tmp1 = (ushort)b0.s1 * (ushort)a4; - ushort tmp2 = (ushort)b0.s2 * (ushort)a4; - ushort tmp3 = (ushort)b0.s3 * (ushort)a4; +#else // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A and matrix B + uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) - acc40 += ((uint)tmp0); - acc41 += ((uint)tmp1); - acc42 += ((uint)tmp2); - acc43 += ((uint)tmp3); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 + uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); + uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); + uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); + uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); + src_addr.s1 += 4 * src1_stride_y; + + ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00); + ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01); + ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02); + ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03); + ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04); + ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05); + ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06); + ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07); + +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10); + ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11); + ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12); + ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13); + ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14); + ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15); + ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16); + ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20); + ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21); + ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, 
b3.s2), acc22); + ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23); + ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24); + ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25); + ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26); + ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30); + ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31); + ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32); + ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33); + ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34); + ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35); + ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36); + ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); + b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); + b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); + b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); + src_addr.s1 += 4 * src1_stride_y; + + ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00); + ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01); + ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02); + ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03); + ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04); + ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05); + ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06); + ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07); + +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10); + ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11); + ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12); + ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13); + ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14); + ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15); + ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16); + ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20); + ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21); + ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22); + ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23); + ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24); + ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25); + ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26); + ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30); + ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31); + ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32); + ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33); + ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34); + ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35); + ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, 
b3.s6), acc36); + ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += 8; } - const int z = get_global_id(2); + for(; i < (int)COLS_A; ++i) + { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + + // Load values from matrix B + uchar8 b0 = vload8(0, src1_ptr + src_addr.s1); + src_addr.s1 += src1_stride_y; + + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s0), acc00); + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s1), acc01); + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s2), acc02); + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s3), acc03); + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s4), acc04); + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s5), acc05); + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s6), acc06); + ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s7), acc07); + +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s0), acc10); + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s1), acc11); + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s2), acc12); + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s3), acc13); + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s4), acc14); + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s5), acc15); + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s6), acc16); + ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s7), acc17); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s0), acc20); + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s1), acc21); + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s2), acc22); + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s3), acc23); + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s4), acc24); + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s5), acc25); + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s6), acc26); + ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s7), acc27); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s0), acc30); + ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s1), acc31); + ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s2), acc32); + ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s3), acc33); + ARM_DOT((uchar4)(a3, 0, 0, 0), 
(uchar4)(b0.s4), acc34); + ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s5), acc35); + ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s6), acc36); + ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s7), acc37); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += 1; + } + + int z = get_global_id(2); // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); + // Compute dst address + __global uchar *dst_addr = dst.ptr; + #if defined(REINTERPRET_OUTPUT_AS_3D) // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings @@ -1833,7 +2275,7 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D; + uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings @@ -1841,41 +2283,43 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D - dst.ptr += z * dst_stride_z * DEPTH_GEMM3D; + dst_addr += z * dst_stride_z * DEPTH_GEMM3D; // Store the result - vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0)); + vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0)); + vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1)); + vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1)); + vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2)); + vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2)); + vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3)); + vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3)); + vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM - dst.ptr += z * dst_stride_z; + dst_addr += z * dst_stride_z; // Store the result - vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 
0 * dst_stride_y)); + vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y)); + vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y)); + vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y)); + vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y)); + vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y)); + vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y)); + vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y)); + vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 #endif // defined(REINTERPRET_OUTPUT_AS_3D) } #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) @@ -1937,6 +2381,70 @@ __kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src), *((__global int *)dst.ptr) = (int)sum_row; } + +#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) +/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A using the arm dot product instruction + * + * @note This stage is needed to handle the offset of matrix product + * https://github.com/google/gemmlowp/blob/master/doc/low-precision.md + * + * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A + * + * @param[in] src_ptr Pointer to the source tensor. 
Supported data type: QASYMM8 + * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) + * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) + * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[out] dst_ptr Pointer to the destination tensor Supported data type: S32 + * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + */ +__kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src), + IMAGE_DECLARATION(dst)) +{ + // Compute source and destination addresses + Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); + Image dst = CONVERT_TO_IMAGE_STRUCT(dst); + + uint sum_row = 0; + + __global const uchar *matrix_a = (__global const uchar *)(src.ptr + get_global_id(0) * src_stride_y + get_global_id(1) * src_stride_z); + + int i = 0; + + // This for loop performs 16 accumulations + for(; i <= ((int)COLS_A - 32); i += 32) + { + uchar16 a0_u8 = vload16(0, matrix_a + i); + + sum_row += arm_dot(a0_u8.s0123, (uchar4)(1)); + sum_row += arm_dot(a0_u8.s4567, (uchar4)(1)); + sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1)); + sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1)); + + a0_u8 = vload16(1, matrix_a + i); + + sum_row += arm_dot(a0_u8.s0123, (uchar4)(1)); + sum_row += arm_dot(a0_u8.s4567, (uchar4)(1)); + sum_row += arm_dot(a0_u8.s89AB, (uchar4)(1)); + sum_row += arm_dot(a0_u8.sCDEF, (uchar4)(1)); + } + + // This for loop performs the leftover accumulations + for(; i < COLS_A; ++i) + { + sum_row += matrix_a[i]; + } + + *((__global int *)dst.ptr) = (int)sum_row; +} +#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) #endif // defined(COLS_A) #if defined(COLS_B) && defined(ROWS_B) @@ -2002,6 +2510,101 @@ __kernel void gemmlowp_matrix_b_reduction(TENSOR3D_DECLARATION(src), #endif // defined(COLS_B) && defined(ROWS_B) #if defined(K_OFFSET) + +/* Helper function used to calculate the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel. + * + * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel), + * and calculates the offset contribution of matrix A and matrix B. + * + * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200) + * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1) + * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. 
-DB_OFFSET=6) + * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches + * + * @param[in] x get_global_id(0) * 4 + * @param[in] y get_global_id(1) + * @param[in] z get_global_id(2) + * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr + * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr + * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr + * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes) + * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor + */ +inline int4 offset_contribution( + int x, + int y, + int z +#if defined(A_OFFSET) + , + IMAGE_DECLARATION(sum_col) +#endif // defined(A_OFFSET) +#if defined(B_OFFSET) + , + IMAGE_DECLARATION(sum_row) +#endif // defined(B_OFFSET) +#if defined(ADD_BIAS) + , + VECTOR_DECLARATION(biases) +#endif // defined(ADD_BIAS) +) +{ + int4 a_offset_s32 = (int4)0; + int4 b_offset_s32 = (int4)0; + + int batch_id = z; +#if defined(DEPTH_INPUT3D) + batch_id /= (int)DEPTH_INPUT3D; +#endif // defined(DEPTH_INPUT3D) + +#if defined(A_OFFSET) + // Compute the offset contribution due to A_OFFSET + __global uchar *sum_col_addr = sum_col_ptr + sum_col_offset_first_element_in_bytes + x * sizeof(int); + + // Compute the offset contribution due to A_OFFSET +#if defined(SUM_COL_HAS_BATCHES) + a_offset_s32 = vload4(0, (__global int *)(sum_col_addr + batch_id * sum_col_stride_y)); +#else // defined(SUM_COL_HAS_BATCHES) + a_offset_s32 = vload4(0, (__global int *)sum_col_addr); +#endif // defined(SUM_COL_HAS_BATCHES) + + a_offset_s32 *= (int4)A_OFFSET; +#endif // defined(A_OFFSET) + +#if defined(B_OFFSET) + // Compute the offset contribution due to A_OFFSET + __global uchar *sum_row_addr = sum_row_ptr + sum_row_offset_first_element_in_bytes + y * sizeof(int); + + // Compute the offset contribution due to B_OFFSET +#if defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D) + b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y)) + (z % (int)DEPTH_INPUT3D) * 
(int)HEIGHT_INPUT3D); +#else // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D) + b_offset_s32 = (int4) * (((__global int *)(sum_row_addr + batch_id * sum_row_stride_y))); +#endif // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D) + b_offset_s32 *= (int4)B_OFFSET; +#endif // defined(B_OFFSET) + +#if defined(ADD_BIAS) + // Add bias + __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int); + + int4 biases_values = vload4(0, (__global int *)bias_addr); + b_offset_s32 += (int4)biases_values; +#endif // defined(ADD_BIAS) + + return (int4)K_OFFSET + a_offset_s32 + b_offset_s32; +} + /* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel. The computation is performed in-place * * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel), @@ -2027,18 +2630,22 @@ __kernel void gemmlowp_matrix_b_reduction(TENSOR3D_DECLARATION(src), * @param[in] mm_result_stride_z Stride of the source tensor in Z dimension (in bytes) * @param[in] mm_result_step_z mm_result_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor - * @param[in] sum_col_ptr Pointer to the source tensor. Supported data type: same as @p mm_result_ptr - * @param[in] sum_col_stride_x Stride of the source tensor in X dimension (in bytes) - * @param[in] sum_col_step_x sum_col_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] sum_col_stride_y Stride of the source tensor in Y dimension (in bytes) - * @param[in] sum_col_step_y sum_col_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] sum_col_offset_first_element_in_bytes The offset of the first element in the source tensor - * @param[in] sum_row_ptr Pointer to the source tensor. Supported data type: same as @p mm_result_ptr - * @param[in] sum_row_stride_x Stride of the source tensor in X dimension (in bytes) - * @param[in] sum_row_step_x sum_row_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] sum_row_stride_y Stride of the source tensor in Y dimension (in bytes) - * @param[in] sum_row_step_y sum_row_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] sum_row_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr + * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. 
Supported data type: same as @p mm_result_ptr + * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr + * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes) + * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor */ __kernel void gemmlowp_offset_contribution(TENSOR3D_DECLARATION(mm_result) #if defined(A_OFFSET) @@ -2049,56 +2656,348 @@ __kernel void gemmlowp_offset_contribution(TENSOR3D_DECLARATION(mm_result) , IMAGE_DECLARATION(sum_row) #endif // defined(B_OFFSET) +#if defined(ADD_BIAS) + , + VECTOR_DECLARATION(biases) +#endif // defined(ADD_BIAS)) ) { - Tensor3D mm_result = CONVERT_TO_TENSOR3D_STRUCT(mm_result); - + const int x = get_global_id(0) * 4; const int y = get_global_id(1); const int z = get_global_id(2); - int4 a_offset_s32 = (int4)0; - int4 b_offset_s32 = (int4)0; + // Compute offset contribution + int4 offset_term_s32 = offset_contribution( + x, y, z +#if defined(A_OFFSET) + , + sum_col_ptr, + sum_col_stride_x, + sum_col_step_x, + sum_col_stride_y, + sum_col_step_y, + sum_col_offset_first_element_in_bytes +#endif // defined(A_OFFSET) +#if defined(B_OFFSET) + , + sum_row_ptr, + sum_row_stride_x, + sum_row_step_x, + sum_row_stride_y, + sum_row_step_y, + sum_row_offset_first_element_in_bytes +#endif // defined(B_OFFSET) +#if defined(ADD_BIAS) + , + biases_ptr, + biases_stride_x, + biases_step_x, + biases_offset_first_element_in_bytes +#endif // defined(ADD_BIAS) + ); - int batch_id = z; -#if defined(DEPTH_INPUT3D) - batch_id /= (int)DEPTH_INPUT3D; -#endif // defined(DEPTH_INPUT3D) + __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z; + int4 in_s32 = vload4(0, (__global int *)mm_result_addr); + + // Add the offset terms to GEMM's result + in_s32 += offset_term_s32; + + // Store the result with the offset contribution + vstore4(in_s32, 0, (__global int *)mm_result_addr); +} + +#if defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT) +/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and it quantizes down to uint8. + * + * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage. + * + * + * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200) + * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. 
-DA_OFFSET=1) + * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6) + * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches + * + * The result before the output stage is: + * + * mm_result[i][k] = mm_result[i][k] + + * (sum_col[k] * A_OFFSET) + + * (sum_row[i] * B_OFFSET) + + * (K_OFFSET) + * + * This result is quantized down to uint8 using the output stage. The output stage computes the following operations: + * + * -# Add offset terms to final result + * -# Multiply each entry of result by result_mult_int + * -# Add bias to final result (if -DADD_BIAS is passed at compile time) + * -# Shift the int32 accumulator by result_shift + * -# Clamp the value between the specified min and max bounds (if -DMIN_BOUND and/or -DMAX_BOUND are passed at compile time) + * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8. + * + * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -DRESULT_MULTIPLIER and -DRESULT_SHIFT + * + * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time + * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND. + * These values can be used to implement "rectified linear unit" activation functions + * + * @param[in] mm_result_ptr Pointer to the source tensor. Supported data type: S32 + * @param[in] mm_result_stride_x Stride of the source tensor in X dimension (in bytes) + * @param[in] mm_result_step_x mm_result_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] mm_result_stride_y Stride of the source tensor in Y dimension (in bytes) + * @param[in] mm_result_step_y mm_result_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] mm_result_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] mm_result_step_z mm_result_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr + * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. 
Supported data type: same as @p mm_result_ptr + * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr + * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes) + * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor + * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8 + * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + */ +__kernel void gemmlowp_offset_contribution_quantize_down(TENSOR3D_DECLARATION(mm_result) #if defined(A_OFFSET) - Image sum_col = CONVERT_TO_IMAGE_STRUCT(sum_col); + , + IMAGE_DECLARATION(sum_col) +#endif // defined(A_OFFSET) +#if defined(B_OFFSET) + , + IMAGE_DECLARATION(sum_row) +#endif // defined(B_OFFSET) + , +#if defined(ADD_BIAS) + VECTOR_DECLARATION(biases), +#endif // defined(ADD_BIAS) + TENSOR3D_DECLARATION(dst)) +{ + const int x = get_global_id(0) * 4; + const int y = get_global_id(1); + const int z = get_global_id(2); - // Compute the offset contribution due to A_OFFSET -#if defined(SUM_COL_HAS_BATCHES) - a_offset_s32 = vload4(0, (__global int *)(sum_col.ptr + batch_id * sum_col_stride_y)); -#else // defined(MATRIX_B_HAS_BATCHES) - a_offset_s32 = vload4(0, (__global int *)(sum_col.ptr)); -#endif // defined(MATRIX_B_HAS_BATCHES) + __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z; - a_offset_s32 *= (int4)A_OFFSET; + // Compute offset contribution + int4 offset_term_s32 = offset_contribution( + x, y, z +#if defined(A_OFFSET) + , + sum_col_ptr, + sum_col_stride_x, + sum_col_step_x, + sum_col_stride_y, + sum_col_step_y, + sum_col_offset_first_element_in_bytes #endif // defined(A_OFFSET) +#if defined(B_OFFSET) + , + sum_row_ptr, + sum_row_stride_x, + sum_row_step_x, + sum_row_stride_y, + sum_row_step_y, + sum_row_offset_first_element_in_bytes +#endif // defined(B_OFFSET) +#if defined(ADD_BIAS) + , + biases_ptr, + biases_stride_x, + biases_step_x, + biases_offset_first_element_in_bytes +#endif // defined(ADD_BIAS) + ); + + __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * 
mm_result_stride_z; + + int4 in_s32 = vload4(0, (__global int *)mm_result_addr); + + // Add the offset terms to GEMM's result + in_s32 += offset_term_s32; + + // -------------- OUTPUT STAGE + + // Add the offset terms to GEMM's result + in_s32 += (int4)RESULT_OFFSET; + + // Multiply by result_mult_int and shift + in_s32 *= RESULT_MULTIPLIER; + in_s32 >>= RESULT_SHIFT; + + uchar4 res = convert_uchar4_sat(in_s32); + +#if defined(MIN_BOUND) + res = max(res, (uchar4)MIN_BOUND); +#endif // defined(MIN_BOUND) +#if defined(MAX_BOUND) + res = min(res, (uchar4)MAX_BOUND); +#endif // defined(MAX_BOUND) + + // Store the result + vstore4(res, 0, dst_addr); +} + +/* OpenCL kernel used to add the offset contribution after @ref CLGEMMLowpMatrixMultiplyKernel and it quantizes down to uint8. + * + * This kernel takes a final int32 accumulator value (the output of @CLGEMMLowpMatrixMultiplyKernel), adds to it the offset contribution of matrix A and matrix B and quantizes to uint8 through the output stage. + * + * + * @attention The k_offset = a_offset * b_offset * k (where k is the number of matrix A columns) needs to be passed at compile time using -DK_OFFSET (i.e. -DK_OFFSET=1200) + * @note In case the offset contribution due to a_offset is required, a_offset needs to be passed at compile time using -DA_OFFSET (i.e. -DA_OFFSET=1) + * @note In case the offset contribution due to b_offset is required, b_offset needs to be passed at compile time using -DB_OFFSET (i.e. -DB_OFFSET=6) + * @note In case sum_col has batches, -DSUM_COL_HAS_BATCHES must be passed at compile time. Usually if gemmlowp is used to accelerate convolution layer, sum_col will not have batches + * + * The result before the output stage is: + * + * mm_result[i][k] = mm_result[i][k] + + * (sum_col[k] * A_OFFSET) + + * (sum_row[i] * B_OFFSET) + + * (K_OFFSET) + * + * This result is quantized down to uint8 using the output stage. The output stage computes the following operations: + * + * -# Compute fixed point multiplication between each entry of input by result_fixedpoint_multiplier + * -# Add bias to final result if bias tensor is not a nullptr + * -# Round to nearest division by a power-of-two using result_shift + * -# Add offset to each result + * -# Clamp the value between the specified min and max bounds + * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8. + * + * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -DRESULT_MULTIPLIER and -DRESULT_SHIFT + * + * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time + * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND. + * These values can be used to implement "rectified linear unit" activation functions + * + * @param[in] mm_result_ptr Pointer to the source tensor. 
Supported data type: S32 + * @param[in] mm_result_stride_x Stride of the source tensor in X dimension (in bytes) + * @param[in] mm_result_step_x mm_result_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] mm_result_stride_y Stride of the source tensor in Y dimension (in bytes) + * @param[in] mm_result_step_y mm_result_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] mm_result_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] mm_result_step_z mm_result_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] mm_result_offset_first_element_in_bytes The offset of the first element in the source tensor + * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr + * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. Supported data type: same as @p mm_result_ptr + * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] biases_ptr (Optional) Pointer to the biases tensor. 
Supported data type: same as @p src_ptr + * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes) + * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor + * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8 + * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + */ +__kernel void gemmlowp_offset_contribution_quantize_down_fixedpoint(TENSOR3D_DECLARATION(mm_result) +#if defined(A_OFFSET) + , + IMAGE_DECLARATION(sum_col) +#endif // defined(A_OFFSET) #if defined(B_OFFSET) - Image sum_row = CONVERT_TO_IMAGE_STRUCT(sum_row); + , + IMAGE_DECLARATION(sum_row) +#endif // defined(B_OFFSET) + , +#if defined(ADD_BIAS) + VECTOR_DECLARATION(biases), +#endif // defined(ADD_BIAS) + TENSOR3D_DECLARATION(dst)) +{ + const int x = get_global_id(0) * 4; + const int y = get_global_id(1); + const int z = get_global_id(2); - // Compute the offset contribution due to B_OFFSET -#if defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D) - b_offset_s32 = (int4) * (((__global int *)(sum_row.ptr + batch_id * sum_row_stride_y)) + (z % (int)DEPTH_INPUT3D) * (int)HEIGHT_INPUT3D + y); -#else // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D) - b_offset_s32 = (int4) * (((__global int *)(sum_row.ptr + batch_id * sum_row_stride_y)) + y); -#endif // defined(HEIGHT_INPUT3D) && defined(DEPTH_INPUT3D) - b_offset_s32 *= (int4)B_OFFSET; + // Compute offset contribution + int4 offset_term_s32 = offset_contribution( + x, y, z +#if defined(A_OFFSET) + , + sum_col_ptr, + sum_col_stride_x, + sum_col_step_x, + sum_col_stride_y, + sum_col_step_y, + sum_col_offset_first_element_in_bytes +#endif // defined(A_OFFSET) +#if defined(B_OFFSET) + , + sum_row_ptr, + sum_row_stride_x, + sum_row_step_x, + sum_row_stride_y, + sum_row_step_y, + sum_row_offset_first_element_in_bytes #endif // defined(B_OFFSET) +#if defined(ADD_BIAS) + , + biases_ptr, + biases_stride_x, + biases_step_x, + biases_offset_first_element_in_bytes +#endif // defined(ADD_BIAS) + ); - const int4 offset_term_s32 = (int4)K_OFFSET + a_offset_s32 + b_offset_s32; + __global uchar *mm_result_addr = mm_result_ptr + mm_result_offset_first_element_in_bytes + x * sizeof(int) + y * mm_result_stride_y + z * mm_result_stride_z; - int4 in_s32 = vload4(0, (__global int *)mm_result.ptr); + __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z; + + int4 in_s32 = vload4(0, (__global int *)mm_result_addr); // Add the offset terms to GEMM's result in_s32 += offset_term_s32; - // Store the result with the offset contribution - vstore4(in_s32, 0, (__global int *)mm_result.ptr); + // -------------- OUTPUT STAGE + + // Multiply by result_mult_int and shift + in_s32 = 
ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(in_s32, RESULT_MULTIPLIER, RESULT_SHIFT, 4); + + // Add the offset terms to GEMM's result + in_s32 += (int4)RESULT_OFFSET; + + uchar4 res = convert_uchar4_sat(in_s32); + +#if defined(MIN_BOUND) + res = max(res, (uchar4)MIN_BOUND); +#endif // defined(MIN_BOUND) +#if defined(MAX_BOUND) + res = min(res, (uchar4)MAX_BOUND); +#endif // defined(MAX_BOUND) + + // Store the result + vstore4(res, 0, dst_addr); } +#endif // defined(K_OFFSET) && defined(RESULT_OFFSET) && defined(RESULT_MULTIPLIER) && defined(RESULT_SHIFT) #endif // defined(K_OFFSET) #if defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT) @@ -2128,10 +3027,10 @@ __kernel void gemmlowp_offset_contribution(TENSOR3D_DECLARATION(mm_result) * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor - * @param[in] biases_ptr Pointer to the biases tensor. Supported data type: same as @p src_ptr - * @param[in] biases_stride_x Stride of the biases tensor in X dimension (in bytes) - * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases tensor + * @param[in] biases_ptr (Optional) Pointer to the biases tensor. Supported data type: same as @p src_ptr + * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes) + * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) @@ -2148,39 +3047,43 @@ __kernel void gemmlowp_output_stage_quantize_down(TENSOR3D_DECLARATION(src), TENSOR3D_DECLARATION(dst)) { // Compute source and destination addresses - Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); - Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst); -#if defined(ADD_BIAS) - Vector biases = CONVERT_TO_VECTOR_STRUCT(biases); -#endif // defined(ADD_BIAS) + int x = get_global_id(0) * 4; + int y = get_global_id(1); + int z = get_global_id(2); - int16 input_values = vload16(0, (__global int *)src.ptr); + __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z; - // Add the offset terms to GEMM's result - input_values += (int16)RESULT_OFFSET; + __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z; + + int4 input_values = vload4(0, (__global int *)src_addr); #if defined(ADD_BIAS) // Add bias - const int16 biases_values = vload16(0, (__global int *)biases.ptr); - input_values += (int16)biases_values; + __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int); + + int4 biases_values = vload4(0, (__global int *)bias_addr); + input_values += (int4)biases_values; #endif // defined(ADD_BIAS) + // Add the offset terms to GEMM's result + input_values += (int4)RESULT_OFFSET; + // Multiply by result_mult_int 
and shift input_values *= RESULT_MULT_INT; input_values >>= RESULT_SHIFT; - uchar16 res = convert_uchar16_sat(input_values); + uchar4 res = convert_uchar4_sat(input_values); #if defined(MIN_BOUND) - res = max(res, (uchar16)MIN_BOUND); + res = max(res, (uchar4)MIN_BOUND); #endif // defined(MIN_BOUND) #if defined(MAX_BOUND) - res = min(res, (uchar16)MAX_BOUND); + res = min(res, (uchar4)MAX_BOUND); #endif // defined(MAX_BOUND) // Store the result - vstore16(res, 0, dst.ptr); + vstore4(res, 0, dst_addr); } #endif // defined(RESULT_OFFSET) && defined(RESULT_MULT_INT) && defined(RESULT_SHIFT) @@ -2197,7 +3100,7 @@ __kernel void gemmlowp_output_stage_quantize_down(TENSOR3D_DECLARATION(src), * -# Clamp the value between the specified min and max bounds * -# Clamp the resulting int32 values to the [0..255] range and cast to QASYMM8. * - * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET, -RESULT_MULT_INT and -DRESULT_SHIFT + * @attention The offset, scalar scale factor and number of bits to shift right of output tensor must be passed at compile time using -DRESULT_OFFSET_AFTER_SHIFT, -DRESULT_FIXEDPOINT_MULTIPLIER and -DRESULT_SHIFT * * @note In case the addition of int32 biases is required, -DADD_BIAS should be passed at compile time * @note In case the clamping of the result is required, the min and max bounds can be passed at compile time using -DMIN_BOUND and -DMAX_BOUND. @@ -2211,10 +3114,10 @@ __kernel void gemmlowp_output_stage_quantize_down(TENSOR3D_DECLARATION(src), * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor - * @param[in] biases_ptr Pointer to the biases tensor. Supported data type: same as @p src_ptr - * @param[in] biases_stride_x Stride of the biases tensor in X dimension (in bytes) - * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases tensor + * @param[in] biases_ptr (Optional) Pointer to the biases tensor. 
Supported data type: same as @p src_ptr + * @param[in] biases_stride_x (Optional) Stride of the biases tensor in X dimension (in bytes) + * @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases tensor * @param[out] dst_ptr Pointer to the destination tensor Supported data type: QASYMM8 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) @@ -2222,58 +3125,50 @@ __kernel void gemmlowp_output_stage_quantize_down(TENSOR3D_DECLARATION(src), * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] dst_stride_z Stride of the source tensor in Z dimension (in bytes) * @param[in] dst_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) - * @param[in] dst_stride_w Stride of the source tensor in W dimension (in bytes) - * @param[in] dst_step_w src_stride_w * number of elements along W processed per workitem(in bytes) * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor */ __kernel void gemmlowp_output_stage_quantize_down_fixedpoint(TENSOR3D_DECLARATION(src), #if defined(ADD_BIAS) VECTOR_DECLARATION(biases), #endif // defined(ADD_BIAS) -#if defined(DST_HEIGHT) - TENSOR4D_DECLARATION(dst)) -#else // defined(DST_HEIGHT) TENSOR3D_DECLARATION(dst)) -#endif // defined(DST_HEIGHT) { // Compute source and destination addresses - Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); -#if defined(DST_HEIGHT) - Tensor4D dst = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(dst, 1); - dst.ptr += get_global_id(0) * dst_step_x + (get_global_id(1) % DST_HEIGHT) * dst_step_y + (get_global_id(1) / DST_HEIGHT) * dst_step_z + get_global_id(2) * dst_step_w; -#else // defined(DST_HEIGHT) - Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst); -#endif // defined(DST_HEIGHT) + int x = get_global_id(0) * 4; + int y = get_global_id(1); + int z = get_global_id(2); -#if defined(ADD_BIAS) - Vector biases = CONVERT_TO_VECTOR_STRUCT(biases); -#endif // defined(ADD_BIAS) + __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(int) + y * src_stride_y + z * src_stride_z; - int16 input_values = vload16(0, (__global int *)src.ptr); + __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x + y * dst_stride_y + z * dst_stride_z; + + int4 input_values = vload4(0, (__global int *)src_addr); #if defined(ADD_BIAS) // Add bias - const int16 biases_values = vload16(0, (__global int *)biases.ptr); - input_values += (int16)biases_values; + __global uchar *bias_addr = biases_ptr + biases_offset_first_element_in_bytes + x * sizeof(int); + + int4 biases_values = vload4(0, (__global int *)bias_addr); + input_values += (int4)biases_values; #endif // defined(ADD_BIAS) // Multiply by result_mult_int and shift - input_values = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(input_values, RESULT_FIXEDPOINT_MULTIPLIER, RESULT_SHIFT, 16); + input_values = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(input_values, RESULT_FIXEDPOINT_MULTIPLIER, RESULT_SHIFT, 4); // Add the offset terms to GEMM's result - input_values += (int16)RESULT_OFFSET_AFTER_SHIFT; + input_values += (int4)RESULT_OFFSET_AFTER_SHIFT; - uchar16 res = convert_uchar16_sat(input_values); + uchar4 res = convert_uchar4_sat(input_values); 
#if defined(MIN_BOUND) - res = max(res, (uchar16)MIN_BOUND); + res = max(res, (uchar4)MIN_BOUND); #endif // defined(MIN_BOUND) #if defined(MAX_BOUND) - res = min(res, (uchar16)MAX_BOUND); + res = min(res, (uchar4)MAX_BOUND); #endif // defined(MAX_BOUND) // Store the result - vstore16(res, 0, dst.ptr); + vstore4(res, 0, dst_addr); } #endif // defined(RESULT_OFFSET_AFTER_SHIFT) && defined(RESULT_FIXEDPOINT_MULTIPLIER) && defined(RESULT_SHIFT) |
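
For reference, the arithmetic that the doxygen comments above describe in prose can be modelled on the host with the short C sketch below. It is not part of the patch: the helper names (offset_term, quantize_down, sat_rounding_doubling_high_mul, rounding_divide_by_pow2) and the values in main() are illustrative stand-ins for the ASYMM_* macros and the compile-time -D options used by the kernels, chosen only to show the order of operations in the offset contribution and in the two quantize-down output stages.

    #include <stdint.h>
    #include <stdio.h>

    /* Offset term added to every int32 accumulator mm_result[y][x]:
     *   K_OFFSET + A_OFFSET * sum_col[x] + B_OFFSET * sum_row[y] (+ bias[x] when ADD_BIAS is set) */
    static int32_t offset_term(int32_t k_offset, int32_t a_offset, int32_t b_offset,
                               int32_t sum_col_x, int32_t sum_row_y, int32_t bias_x)
    {
        return k_offset + a_offset * sum_col_x + b_offset * sum_row_y + bias_x;
    }

    /* Integer output stage: add RESULT_OFFSET, multiply by RESULT_MULTIPLIER,
     * arithmetic shift right by RESULT_SHIFT, clamp to [0, 255]. */
    static uint8_t quantize_down(int32_t acc, int32_t result_offset,
                                 int32_t result_multiplier, int result_shift)
    {
        int64_t v = ((int64_t)acc + result_offset) * result_multiplier; /* 64-bit only to keep the sketch overflow-free */
        v >>= result_shift;
        if(v < 0)   v = 0;
        if(v > 255) v = 255;
        return (uint8_t)v;
    }

    /* Saturating rounding doubling high multiply of two Q0.31 fixed-point values. */
    static int32_t sat_rounding_doubling_high_mul(int32_t a, int32_t b)
    {
        if(a == INT32_MIN && b == INT32_MIN)
            return INT32_MAX;
        const int64_t ab    = (int64_t)a * (int64_t)b;
        const int64_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
        return (int32_t)((ab + nudge) / (1ll << 31));
    }

    /* Rounding division by 2^exponent (round half away from zero). */
    static int32_t rounding_divide_by_pow2(int32_t x, int exponent)
    {
        const int32_t mask      = (int32_t)((1u << exponent) - 1u);
        const int32_t remainder = x & mask;
        const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
        return (x >> exponent) + (remainder > threshold ? 1 : 0);
    }

    /* Fixed-point output stage: Q0.31 multiply, rounding right shift,
     * add the output offset, clamp to [0, 255]. */
    static uint8_t quantize_down_fixedpoint(int32_t acc, int32_t fixedpoint_multiplier,
                                            int result_shift, int32_t result_offset_after_shift)
    {
        int32_t v = sat_rounding_doubling_high_mul(acc, fixedpoint_multiplier);
        v = rounding_divide_by_pow2(v, result_shift);
        v += result_offset_after_shift;
        if(v < 0)   v = 0;
        if(v > 255) v = 255;
        return (uint8_t)v;
    }

    int main(void)
    {
        /* Arbitrary example values. */
        int32_t acc = 12345;                       /* int32 accumulator produced by the matrix multiply kernel */
        acc += offset_term(1200, 1, 6, 37, 52, 0); /* what gemmlowp_offset_contribution adds in place */

        printf("integer output stage    : %u\n", (unsigned)quantize_down(acc, -3, 3, 8));
        printf("fixed-point output stage: %u\n", (unsigned)quantize_down_fixedpoint(acc, 1 << 30 /* ~0.5 in Q0.31 */, 5, 10));
        return 0;
    }

The integer stage simply multiplies and shifts, whereas the fixed-point stage performs a Q0.31 multiply followed by a rounding power-of-two division before the output offset is added, which is what the ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE call in the kernels above is doing for 4 elements at a time.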