Diffstat (limited to 'src/core/CL/cl_kernels/gemmlowp.cl')
-rw-r--r-- src/core/CL/cl_kernels/gemmlowp.cl | 1040
1 file changed, 0 insertions, 1040 deletions
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index d6494fe380..fc90dbd16c 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -193,168 +193,6 @@
(n0, k0, a, b, c); \
})
-#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
-/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
- * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel before running the matrix multiplication
- *
- * @note The number of matrix B columns needs to be passed at compile time using -DCOLS_B: e.g. -DCOLS_B=1024
- * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (e.g. -DTRANSPOSE1XW_WIDTH_STEP=2)
- * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
- * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
- * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
- * (HEIGHT_GEMM3D * DEPTH_GEMM3D) must be equal to the number of columns of the NOT reshaped matrix A
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src0_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src1_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix. Supported data type: S32
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad (Optional) Bottom padding in units of elements (only if REINTERPRET_OUTPUT_AS_3D is defined)
- */
-__kernel void gemmlowp_mm_interleaved_transposed_midgard(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- ,
- uint cross_plane_pad
-#endif // REINTERPRET_OUTPUT_AS_3D
- )
-{
- const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
- const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
- const int z = get_global_id(2);
-
- // Offset
- const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
- const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;
-
- // src_addr_a = address of matrix A
- // src_addr_b = address of matrix B
- __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if matrix B has 3 dimensions while matrix A has more than 3
- src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr_b += z * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- // Compute end row address for matrix B
- __global uchar *src_end_addr_b = src_addr_b + COLS_B;
-
- src_addr_a += offset_row_a;
- src_addr_b += offset_row_b;
-
- // Reset accumulators
- int4 c00 = 0;
- int4 c10 = 0;
- int4 c20 = 0;
- int4 c30 = 0;
-
- for(; src_addr_b <= (src_end_addr_b - (int)(8 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * TRANSPOSE1XW_WIDTH_STEP)
- {
- // Load values from matrix A (interleaved) and matrix B (transposed)
- int4 a0 = convert_int4(vload4(0, src_addr_a));
- int4 b0 = convert_int4(vload4(0, src_addr_b));
-
- c00 += (int4)a0.s0 * b0;
- c10 += (int4)a0.s1 * b0;
- c20 += (int4)a0.s2 * b0;
- c30 += (int4)a0.s3 * b0;
-
- a0 = convert_int4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
- b0 = convert_int4(vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP));
-
- c00 += (int4)a0.s0 * b0;
- c10 += (int4)a0.s1 * b0;
- c20 += (int4)a0.s2 * b0;
- c30 += (int4)a0.s3 * b0;
- }
-
- for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
- {
- // Load values from matrix A (interleaved) and matrix B (transposed)
- int4 a0 = convert_int4(vload4(0, src_addr_a));
- int4 b0 = convert_int4(vload4(0, src_addr_b));
-
- c00 += (int4)a0.s0 * b0;
- c10 += (int4)a0.s1 * b0;
- c20 += (int4)a0.s2 * b0;
- c30 += (int4)a0.s3 * b0;
- }
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zout) is calculated by dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
- uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
- zout = min(DEPTH_GEMM3D - 1, zout);
-
- // Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply dst_stride_z by DEPTH_GEMM3D
- dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
-
- // Store 4x4 block
- vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
- vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
- vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
- vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
- // Add offset for batched GEMM
- dst.ptr += z * dst_stride_z;
-
- // Store 4x4 block
- vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
- vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
- vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
- vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
-}
-#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
-
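For reference, the inner loop of the removed midgard kernel is a rank-1 update per step along K: four interleaved A values and four transposed B values are widened from uchar to int, and row i of the 4x4 accumulator gathers a.si * b. A minimal scalar sketch of that update, assuming MULT_INTERLEAVE4X4_HEIGHT == 1 and TRANSPOSE1XW_WIDTH_STEP == 1 (the helper name ref_gemmlowp_tile_4x4 and its parameters are illustrative, not part of the library):

    // Hedged sketch: scalar reference for the 4x4 tile accumulation above.
    // lhs points at 4 interleaved A values per K step, rhs at 4 transposed
    // B values per K step; acc holds the four int4 rows (c00..c30).
    inline void ref_gemmlowp_tile_4x4(__global const uchar *lhs,
                                      __global const uchar *rhs,
                                      int k, int4 acc[4])
    {
        for(int i = 0; i < k; ++i)
        {
            int4 a = convert_int4(vload4(0, lhs + 4 * i)); // interleaved A block
            int4 b = convert_int4(vload4(0, rhs + 4 * i)); // transposed B block
            acc[0] += (int4)a.s0 * b;
            acc[1] += (int4)a.s1 * b;
            acc[2] += (int4)a.s2 * b;
            acc[3] += (int4)a.s3 * b;
        }
    }

The kernel itself unrolls this by two and folds the interleave/transpose multipliers into the pointer increments; the arithmetic per step is the same.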
#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
#define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X)
#define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X)
@@ -631,884 +469,6 @@ __kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
-
-/** OpenCL kernel optimized for Bifrost architectures that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
- *
- * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
- *
- * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
- * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
- * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
- * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
- * (HEIGHT_GEMM3D * DEPTH_GEMM3D) must be equal to the number of columns of the NOT reshaped matrix A
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src0_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src1_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix. Supported data type: S32
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] src_cross_plane_pad (Optional) Bottom padding in units of elements for the input tensor (only if REINTERPRET_INPUT_AS_3D is defined)
- * @param[in] dst_cross_plane_pad (Optional) Bottom padding in units of elements for the output tensor (only if REINTERPRET_OUTPUT_AS_3D is defined)
- */
-__kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z
-#if defined(REINTERPRET_INPUT_AS_3D)
- ,
- uint src_cross_plane_pad
-#endif // REINTERPRET_INPUT_AS_3D
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- ,
- uint dst_cross_plane_pad
-#endif // REINTERPRET_OUTPUT_AS_3D
- )
-{
- int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
-
- // Compute starting address for matrix A and Matrix B
- int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
-
- // Update address for the matrix A
- src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
-
- // Update address for the matrix B
- src_addr.s1 += idx;
-
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zin = min(DEPTH_GEMM3D - 1, zin);
-
- // Add offset due to the cross plane paddings
- zin *= (src_cross_plane_pad * src0_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply src0_stride_z by DEPTH_GEMM3D
- src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
-
-#else // defined(REINTERPRET_INPUT_AS_3D)
-
- // Add offset for batched GEMM
- src_addr.s0 += get_global_id(2) * src0_stride_z;
-
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if matrix B has 3 dimensions while matrix A has more than 3
- src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr.s1 += get_global_id(2) * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- int end_row_vec_a = src_addr.s0 + COLS_A;
-
- uint acc00 = 0;
- uint acc01 = 0;
- uint acc02 = 0;
- uint acc03 = 0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uint acc10 = 0;
- uint acc11 = 0;
- uint acc12 = 0;
- uint acc13 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uint acc20 = 0;
- uint acc21 = 0;
- uint acc22 = 0;
- uint acc23 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uint acc30 = 0;
- uint acc31 = 0;
- uint acc32 = 0;
- uint acc33 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uint acc40 = 0;
- uint acc41 = 0;
- uint acc42 = 0;
- uint acc43 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
-
- for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
- {
- // Load values from matrix A
- uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- // Load values from matrix B
- uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
-
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a0.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a0.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a0.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a0.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a0.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a0.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a0.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a0.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a0.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a0.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a0.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a0.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a0.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a0.s3;
-
- acc00 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a1.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a1.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a1.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a1.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a1.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a1.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a1.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a1.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a1.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a1.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a1.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a1.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a1.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a1.s3;
-
- acc10 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a2.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a2.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a2.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a2.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a2.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a2.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a2.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a2.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a2.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a2.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a2.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a2.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a2.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a2.s3;
-
- acc20 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a3.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a3.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a3.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a3.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a3.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a3.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a3.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a3.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a3.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a3.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a3.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a3.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a3.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a3.s3;
-
- acc30 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a4.s0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a4.s0;
-
- ushort tmp4 = (ushort)b1.s0 * (ushort)a4.s1;
- ushort tmp5 = (ushort)b1.s1 * (ushort)a4.s1;
- ushort tmp6 = (ushort)b1.s2 * (ushort)a4.s1;
- ushort tmp7 = (ushort)b1.s3 * (ushort)a4.s1;
-
- ushort tmp8 = (ushort)b2.s0 * (ushort)a4.s2;
- ushort tmp9 = (ushort)b2.s1 * (ushort)a4.s2;
- ushort tmpA = (ushort)b2.s2 * (ushort)a4.s2;
- ushort tmpB = (ushort)b2.s3 * (ushort)a4.s2;
-
- ushort tmpC = (ushort)b3.s0 * (ushort)a4.s3;
- ushort tmpD = (ushort)b3.s1 * (ushort)a4.s3;
- ushort tmpE = (ushort)b3.s2 * (ushort)a4.s3;
- ushort tmpF = (ushort)b3.s3 * (ushort)a4.s3;
-
- acc40 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
- acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
- acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
- acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- }
-
- for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
- {
- // Load values from matrix A
- uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- // Load values from matrix B
- uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
-
- // Accumulate
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
-
- acc00 += ((uint)tmp0);
- acc01 += ((uint)tmp1);
- acc02 += ((uint)tmp2);
- acc03 += ((uint)tmp3);
- }
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
-
- acc10 += ((uint)tmp0);
- acc11 += ((uint)tmp1);
- acc12 += ((uint)tmp2);
- acc13 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
-
- acc20 += ((uint)tmp0);
- acc21 += ((uint)tmp1);
- acc22 += ((uint)tmp2);
- acc23 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
-
- acc30 += ((uint)tmp0);
- acc31 += ((uint)tmp1);
- acc32 += ((uint)tmp2);
- acc33 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- {
- // Accumulate
- ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
- ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
- ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
- ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
-
- acc40 += ((uint)tmp0);
- acc41 += ((uint)tmp1);
- acc42 += ((uint)tmp2);
- acc43 += ((uint)tmp3);
- }
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- }
-
- const int z = get_global_id(2);
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zout) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D;
- zout = min(DEPTH_GEMM3D - 1, zout);
-
- // Add offset due to the cross plane paddings
- zout *= (dst_cross_plane_pad * dst_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply dst_stride_z by DEPTH_GEMM3D
- dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
- // Add offset for batched GEMM
- dst.ptr += z * dst_stride_z;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
-}
-
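Both of these kernels share the same REINTERPRET_OUTPUT_AS_3D bookkeeping: compute which plane each output row falls into, then skip cross_plane_pad padded rows for every plane boundary already crossed. A compact sketch of that offset computation, with an illustrative helper name (zout_byte_offset is not a library symbol):

    // Hedged sketch of the cross-plane offset logic used above. For example,
    // with height_gemm3d = 4 and row = 6, rows 6..9 fall in planes 1,1,2,2,
    // so the first two rows skip one pad and the last two skip two.
    inline uint4 zout_byte_offset(uint row, uint cross_plane_pad,
                                  uint dst_stride_y, uint height_gemm3d,
                                  uint depth_gemm3d)
    {
        // Plane index of each of the 4 consecutive output rows
        uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)row) / (uint4)height_gemm3d;
        zout       = min(zout, (uint4)(depth_gemm3d - 1));
        // Each crossed plane adds cross_plane_pad padded rows, in bytes
        return zout * cross_plane_pad * dst_stride_y;
    }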
-#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-/** OpenCL kernel optimized to use the dot product instruction that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
- *
- * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
- *
- * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
- * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
- * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
- * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
- * (HEIGHT_GEMM3D * DEPTH_GEMM3D) must be equal to the number of columns of the NOT reshaped matrix A
- *
- * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
- * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src0_step_x src0_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
- * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
- * @param[in] src1_step_x src1_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
- * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
- * @param[out] dst_ptr Pointer to the destination matrix. Supported data type: S32
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per work item (in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per work item (in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] src_cross_plane_pad (Optional) Bottom padding in units of elements for the input tensor (only if REINTERPRET_INPUT_AS_3D is defined)
- * @param[in] dst_cross_plane_pad (Optional) Bottom padding in units of elements for the output tensor (only if REINTERPRET_OUTPUT_AS_3D is defined)
- */
-__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst),
- uint src0_stride_z,
- uint src1_stride_z,
- uint dst_stride_z
-#if defined(REINTERPRET_INPUT_AS_3D)
- ,
- uint src_cross_plane_pad
-#endif // REINTERPRET_INPUT_AS_3D
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- ,
- uint dst_cross_plane_pad
-#endif // REINTERPRET_OUTPUT_AS_3D
- )
-{
- int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
-
- // Compute starting address for matrix A and Matrix B
- int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
-
- // Update address for the matrix A
- src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
-
- // Update address for the matrix B
- src_addr.s1 += idx;
-
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zin = min(DEPTH_GEMM3D - 1, zin);
-
- // Add offset due to the cross plane paddings
- zin *= (src_cross_plane_pad * src0_stride_y);
-
- zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y;
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply src0_stride_z by DEPTH_GEMM3D
- src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
-
-#else // defined(REINTERPRET_INPUT_AS_3D)
-
- // Add offset for batched GEMM
- src_addr.s0 += get_global_id(2) * src0_stride_z;
-
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
-#if defined(MATRIX_B_DEPTH)
- // Do not slide matrix B if matrix B has 3 dimensions while matrix A has more than 3
- src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
-#else // defined(MATRIX_B_DEPTH)
- src_addr.s1 += get_global_id(2) * src1_stride_z;
-#endif // defined(MATRIX_B_DEPTH)
-
- uint acc00 = 0;
- uint acc01 = 0;
- uint acc02 = 0;
- uint acc03 = 0;
- uint acc04 = 0;
- uint acc05 = 0;
- uint acc06 = 0;
- uint acc07 = 0;
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uint acc10 = 0;
- uint acc11 = 0;
- uint acc12 = 0;
- uint acc13 = 0;
- uint acc14 = 0;
- uint acc15 = 0;
- uint acc16 = 0;
- uint acc17 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uint acc20 = 0;
- uint acc21 = 0;
- uint acc22 = 0;
- uint acc23 = 0;
- uint acc24 = 0;
- uint acc25 = 0;
- uint acc26 = 0;
- uint acc27 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uint acc30 = 0;
- uint acc31 = 0;
- uint acc32 = 0;
- uint acc33 = 0;
- uint acc34 = 0;
- uint acc35 = 0;
- uint acc36 = 0;
- uint acc37 = 0;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- // A and B src indices get incremented at the same time.
- int i = 0;
- for(; i <= ((int)COLS_A - 8); i += 8)
- {
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A and matrix B
- uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A and matrix B
- uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
- uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
- src_addr.s1 += 4 * src1_stride_y;
-
- ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
- ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
- ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
- ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
- ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
- ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
- ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
- ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
- ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
- ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
- ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
- ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
- ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
- ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
- ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
- ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
- ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
- ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
- ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
- ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
- ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
- ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
- ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
- ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
- ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
- ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
- ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
- ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
- ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
- b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
- b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
- b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
- src_addr.s1 += 4 * src1_stride_y;
-
- ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00);
- ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01);
- ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02);
- ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03);
- ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04);
- ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05);
- ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06);
- ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07);
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10);
- ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11);
- ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12);
- ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13);
- ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14);
- ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15);
- ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16);
- ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20);
- ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21);
- ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22);
- ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23);
- ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24);
- ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25);
- ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26);
- ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30);
- ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31);
- ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32);
- ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33);
- ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34);
- ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35);
- ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36);
- ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37);
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- src_addr.s0 += 8;
- }
-
- for(; i < (int)COLS_A; ++i)
- {
-#if defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A
- uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
- // Load values from matrix A
- uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_INPUT_AS_3D)
-
- // Load values from matrix B
- uchar8 b0 = vload8(0, src1_ptr + src_addr.s1);
- src_addr.s1 += src1_stride_y;
-
- acc00 += (uint)a0 * b0.s0;
- acc01 += (uint)a0 * b0.s1;
- acc02 += (uint)a0 * b0.s2;
- acc03 += (uint)a0 * b0.s3;
- acc04 += (uint)a0 * b0.s4;
- acc05 += (uint)a0 * b0.s5;
- acc06 += (uint)a0 * b0.s6;
- acc07 += (uint)a0 * b0.s7;
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc10 += (uint)a1 * b0.s0;
- acc11 += (uint)a1 * b0.s1;
- acc12 += (uint)a1 * b0.s2;
- acc13 += (uint)a1 * b0.s3;
- acc14 += (uint)a1 * b0.s4;
- acc15 += (uint)a1 * b0.s5;
- acc16 += (uint)a1 * b0.s6;
- acc17 += (uint)a1 * b0.s7;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc20 += (uint)a2 * b0.s0;
- acc21 += (uint)a2 * b0.s1;
- acc22 += (uint)a2 * b0.s2;
- acc23 += (uint)a2 * b0.s3;
- acc24 += (uint)a2 * b0.s4;
- acc25 += (uint)a2 * b0.s5;
- acc26 += (uint)a2 * b0.s6;
- acc27 += (uint)a2 * b0.s7;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc30 += (uint)a3 * b0.s0;
- acc31 += (uint)a3 * b0.s1;
- acc32 += (uint)a3 * b0.s2;
- acc33 += (uint)a3 * b0.s3;
- acc34 += (uint)a3 * b0.s4;
- acc35 += (uint)a3 * b0.s5;
- acc36 += (uint)a3 * b0.s6;
- acc37 += (uint)a3 * b0.s7;
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
- src_addr.s0 += 1;
- }
-
- int z = get_global_id(2);
-
- // Compute destination address
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
-
- // Compute dst address
- __global uchar *dst_addr = dst.ptr;
-
-#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zout) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
- uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
- zout = min(DEPTH_GEMM3D - 1, zout);
-
- // Add offset due to the cross plane paddings
- zout *= (dst_cross_plane_pad * dst_stride_y);
-
- // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
- // multiply dst_stride_z by DEPTH_GEMM3D
- dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
- vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
- vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
- vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
- vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
-#else // defined(REINTERPRET_OUTPUT_AS_3D)
- // Add offset for batched GEMM
- dst_addr += z * dst_stride_z;
-
- // Store the result
- vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y));
- vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y));
- vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y));
- vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y));
- vstore4((int4)(acc34, acc35, acc36, acc37), 1, (__global int *)(dst_addr + 3 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#endif // defined(REINTERPRET_OUTPUT_AS_3D)
-}
-#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
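The dot8 variant above relies on the ARM_DOT macro, which is expected to map onto the cl_arm_integer_dot_product built-ins when the extension is present. Its semantics are a plain 4-way uchar dot product accumulated into a uint; a portable scalar equivalent, for illustration only (ARM_DOT_PORTABLE is not a library macro):

    // Hedged sketch: portable equivalent of ARM_DOT(a, b, c) as used above,
    // for uchar4 a and b accumulating into uint c. This only documents the
    // semantics; the real macro targets the hardware dot product instruction.
    #define ARM_DOT_PORTABLE(a, b, c)                                       \
        ((c) += (uint)(a).s0 * (uint)(b).s0 + (uint)(a).s1 * (uint)(b).s1 + \
                (uint)(a).s2 * (uint)(b).s2 + (uint)(a).s3 * (uint)(b).s3)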
#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(M) && defined(N)