aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemmlowp.cl
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-11-09 11:05:46 +0000
committerGian Marco Iodice <gianmarco.iodice@arm.com>2018-11-09 11:05:46 +0000
commit742df6cddab3b97281ce73e3c53df2793bb04c15 (patch)
tree4714d252754c003793502d0529693bd7d8512765 /src/core/CL/cl_kernels/gemmlowp.cl
parent027ce5b7234f694cb9047ac970aa1457d683c471 (diff)
downloadComputeLibrary-742df6cddab3b97281ce73e3c53df2793bb04c15.tar.gz
COMPMID-1451 - Removed unused OpenCL kernel from gemmlowp.cl
Removed gemmlowp_mm_bifrost_transposed_dot8 kernel as not used Change-Id: I43cf463a3a4c0cdb2808621c534ffd5c9fd47ca1
Diffstat (limited to 'src/core/CL/cl_kernels/gemmlowp.cl')
-rw-r--r--src/core/CL/cl_kernels/gemmlowp.cl383
1 files changed, 0 insertions, 383 deletions
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index f2467b721a..8c1fa548e4 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -1940,390 +1940,7 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
-
-__kernel void gemmlowp_mm_bifrost_transposed_dot8(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z
-#if defined(REINTERPRET_INPUT_AS_3D)
- ,
- uint src_cross_plane_pad
-#endif // REINTERPRET_INPUT_AS_3D
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- ,
- uint dst_cross_plane_pad
-#endif // REINTERPRET_OUTPUT_AS_3D)
- )
-{
- int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
-
- // Compute starting address for matrix A and Matrix B
- int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
-
- // Update address for the matrix A
- src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
-
- // Update address for the matrix B
- src_addr.s1 += idx;
-
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zin = min(DEPTH_GEMM3D - 1, zin);
-
- // Add offset due to the cross plane paddings
- zin *= (src_cross_plane_pad * src0_stride_y);
-
- zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y;
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply src0_stride_z by DEPTH_GEMM3D
- src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
-
-#else // defined(REINTERPRET_INPUT_AS_3D)
-
- // Add offset for batched GEMM
- src_addr.s0 += get_global_id(2) * src0_stride_z;
-
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
- src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr.s1 += get_global_id(2) * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- uint acc00 = 0;
- uint acc01 = 0;
- uint acc02 = 0;
- uint acc03 = 0;
- uint acc04 = 0;
- uint acc05 = 0;
- uint acc06 = 0;
- uint acc07 = 0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uint acc10 = 0;
- uint acc11 = 0;
- uint acc12 = 0;
- uint acc13 = 0;
- uint acc14 = 0;
- uint acc15 = 0;
- uint acc16 = 0;
- uint acc17 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uint acc20 = 0;
- uint acc21 = 0;
- uint acc22 = 0;
- uint acc23 = 0;
- uint acc24 = 0;
- uint acc25 = 0;
- uint acc26 = 0;
- uint acc27 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uint acc30 = 0;
- uint acc31 = 0;
- uint acc32 = 0;
- uint acc33 = 0;
- uint acc34 = 0;
- uint acc35 = 0;
- uint acc36 = 0;
- uint acc37 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- // A and B src indices get incremented at the same time.
- int i = 0;
- for(; i <= ((int)COLS_A - 8); i += 8)
- {
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A and matrix B
- uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A and matrix B
- uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
- uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
- src_addr.s1 += 4 * src1_stride_y;
-
- ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
- ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
- ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
- ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
- ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
- ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
- ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
- ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
- ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
- ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
- ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
- ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
- ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
- ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
- ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
- ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
- ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
- ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
- ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
- ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
- ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
- ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
- ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
- ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
- ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
- ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
- ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
- ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
- ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
- src_addr.s1 += 4 * src1_stride_y;
-
- ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
- ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
- ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
- ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
- ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
- ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
- ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
- ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
- ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
- ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
- ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
- ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
- ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
- ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
- ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
- ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
- ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
- ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
- ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
- ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
- ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
- ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
- ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
- ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
- ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
- ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
- ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
- ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
- ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- src_addr.s0 += 8;
- }
-
- for(; i < (int)COLS_A; ++i)
- {
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A
- uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A
- uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
- // Load values from matrix B
- uchar8 b0 = vload8(0, src1_ptr + src_addr.s1);
- src_addr.s1 += src1_stride_y;
-
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s0), acc00);
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s1), acc01);
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s2), acc02);
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s3), acc03);
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s4), acc04);
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s5), acc05);
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s6), acc06);
- ARM_DOT((uchar4)(a0, 0, 0, 0), (uchar4)(b0.s7), acc07);
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s0), acc10);
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s1), acc11);
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s2), acc12);
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s3), acc13);
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s4), acc14);
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s5), acc15);
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s6), acc16);
- ARM_DOT((uchar4)(a1, 0, 0, 0), (uchar4)(b0.s7), acc17);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s0), acc20);
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s1), acc21);
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s2), acc22);
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s3), acc23);
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s4), acc24);
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s5), acc25);
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s6), acc26);
- ARM_DOT((uchar4)(a2, 0, 0, 0), (uchar4)(b0.s7), acc27);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s0), acc30);
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s1), acc31);
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s2), acc32);
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s3), acc33);
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s4), acc34);
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s5), acc35);
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s6), acc36);
- ARM_DOT((uchar4)(a3, 0, 0, 0), (uchar4)(b0.s7), acc37);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- src_addr.s0 += 1;
- }
-
- int z = get_global_id(2);
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Compute dst address
- __global uchar *dst_addr = dst.ptr;
-
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zout = min(DEPTH_GEMM3D - 1, zout);
-
- // Add offset due to the cross plane paddings
- zout *= (dst_cross_plane_pad * dst_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply dst_stride_z by DEPTH_GEMM3D
- dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
- vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
- vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
- vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
- vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
- // Add offset for batched GEMM
- dst_addr += z * dst_stride_z;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y));
- vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y));
- vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y));
- vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
- vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
-}
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-
#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
#if defined(COLS_A)