diff options
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r-- | src/core/CL/cl_kernels/gemm.cl | 375 |
1 files changed, 350 insertions, 25 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index 5a6efe64b9..932e0d681a 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -88,6 +88,11 @@ __kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src), * * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float) * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) + * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time: + * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D + * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor. + * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor + * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * * @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes) @@ -105,9 +110,15 @@ __kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src), * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes) * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix + * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D) */ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src), - TENSOR3D_DECLARATION(dst)) + TENSOR3D_DECLARATION(dst) +#if defined(REINTERPRET_INPUT_AS_3D) + , + uint cross_plane_pad +#endif // REINTERPRET_INPUT_AS_3D + ) { // Compute source and destination addresses uint x = get_global_id(0); @@ -124,6 +135,45 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src), // Add offset for batched GEMM dst_addr_in_bytes += z * dst_stride_z; +#if defined(REINTERPRET_INPUT_AS_3D) + __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y; + + // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension + // in order to take into account the presence of possible cross plane paddings + // + // | | + // | plane0 | + // | | + // |__________________| + // |******************| + // | cross_plane_pad | + // |******************| + // | | + // | plane1 | + // | | + // |__________________| + + // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D + uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D; + zin = min(DEPTH_GEMM3D - 1, zin); + + // Add offset due to the cross plane paddings + zin *= (cross_plane_pad * src_stride_y); + + // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we + // multiply src_stride_z by DEPTH_GEMM3D + input_ptr += z * src_stride_z * DEPTH_GEMM3D; + + // Load values from Matrix A + VEC_DATA_TYPE(DATA_TYPE, 4) + a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0)); + VEC_DATA_TYPE(DATA_TYPE, 4) + a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1)); + VEC_DATA_TYPE(DATA_TYPE, 4) + a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2)); + VEC_DATA_TYPE(DATA_TYPE, 4) + a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3)); +#else // defined(REINTERPRET_INPUT_AS_3D) __global uchar *input_ptr = src.ptr; // Load values from Matrix A @@ -135,6 +185,7 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src), a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y)); VEC_DATA_TYPE(DATA_TYPE, 4) a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y)); +#endif // defined(REINTERPRET_INPUT_AS_3D) VEC_DATA_TYPE(DATA_TYPE, 4) val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0); @@ -188,7 +239,7 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src), * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -366,7 +417,7 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -679,7 +730,7 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -853,7 +904,7 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -1095,7 +1146,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: + * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor @@ -1122,7 +1174,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) + * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -1130,9 +1183,13 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), uint src0_stride_z, uint src1_stride_z, uint dst_stride_z +#if defined(REINTERPRET_INPUT_AS_3D) + , + uint src_cross_plane_pad +#endif // REINTERPRET_INPUT_AS_3D #if defined(REINTERPRET_OUTPUT_AS_3D) , - uint cross_plane_pad + uint dst_cross_plane_pad #endif // REINTERPRET_OUTPUT_AS_3D ) { @@ -1147,9 +1204,40 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), // Update address for the matrix B src_addr.s1 += idx * sizeof(DATA_TYPE); +#if defined(REINTERPRET_INPUT_AS_3D) + // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension + // in order to take into account the presence of possible cross plane paddings + // + // | | + // | plane0 | + // | | + // |__________________| + // |******************| + // | cross_plane_pad | + // |******************| + // | | + // | plane1 | + // | | + // |__________________| + + // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D + uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zin = min(DEPTH_GEMM3D - 1, zin); + + // Add offset due to the cross plane paddings + zin *= (src_cross_plane_pad * src0_stride_y); + + // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we + // multiply src0_stride_z by DEPTH_GEMM3D + src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; + +#else // defined(REINTERPRET_INPUT_AS_3D) + // Add offset for batched GEMM src_addr.s0 += get_global_id(2) * src0_stride_z; +#endif // defined(REINTERPRET_INPUT_AS_3D) + #if defined(MATRIX_B_DEPTH) // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; @@ -1172,6 +1260,23 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y)) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + VEC_DATA_TYPE(DATA_TYPE, 2) + a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + VEC_DATA_TYPE(DATA_TYPE, 2) + a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + VEC_DATA_TYPE(DATA_TYPE, 2) + a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + VEC_DATA_TYPE(DATA_TYPE, 2) + a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A VEC_DATA_TYPE(DATA_TYPE, 2) a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); @@ -1187,6 +1292,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), VEC_DATA_TYPE(DATA_TYPE, 2) a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix B VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y)); @@ -1210,6 +1317,19 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y)) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -1221,6 +1341,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix B VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); @@ -1280,7 +1402,7 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings - zout *= (cross_plane_pad * dst_stride_y); + zout *= (dst_cross_plane_pad * dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D @@ -1335,7 +1457,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: + * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor @@ -1362,7 +1485,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) + * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -1370,9 +1494,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), uint src0_stride_z, uint src1_stride_z, uint dst_stride_z +#if defined(REINTERPRET_INPUT_AS_3D) + , + uint src_cross_plane_pad +#endif // REINTERPRET_INPUT_AS_3D #if defined(REINTERPRET_OUTPUT_AS_3D) , - uint cross_plane_pad + uint dst_cross_plane_pad #endif // REINTERPRET_OUTPUT_AS_3D ) { @@ -1387,9 +1515,40 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), // Update address for matrix B src_addr.s1 += idx * sizeof(float); +#if defined(REINTERPRET_INPUT_AS_3D) + // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension + // in order to take into account the presence of possible cross plane paddings + // + // | | + // | plane0 | + // | | + // |__________________| + // |******************| + // | cross_plane_pad | + // |******************| + // | | + // | plane1 | + // | | + // |__________________| + + // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D + uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zin = min(DEPTH_GEMM3D - 1, zin); + + // Add offset due to the cross plane paddings + zin *= (src_cross_plane_pad * src0_stride_y); + + // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we + // multiply src0_stride_z by DEPTH_GEMM3D + src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; + +#else // defined(REINTERPRET_INPUT_AS_3D) + // Add offset for batched GEMM src_addr.s0 += get_global_id(2) * src0_stride_z; +#endif // defined(REINTERPRET_INPUT_AS_3D) + #if defined(MATRIX_B_DEPTH) // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; @@ -1428,6 +1587,19 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), int i = 0; for(; i <= ((int)COLS_A - 4); i += 4) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A and matrix B + float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A and matrix B float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -1439,6 +1611,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); src_addr.s1 += src1_stride_y; @@ -1579,8 +1753,21 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), for(; i < (int)COLS_A; ++i) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A - float a0 = *((__global float *)(src0_ptr + src_addr.s0)); + float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -1590,6 +1777,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix B float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); src_addr.s1 += src1_stride_y; @@ -1676,7 +1865,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings - zout *= (cross_plane_pad * dst_stride_y); + zout *= (dst_cross_plane_pad * dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D @@ -1723,7 +1912,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: + * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor @@ -1750,7 +1940,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) + * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -1758,9 +1949,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), uint src0_stride_z, uint src1_stride_z, uint dst_stride_z +#if defined(REINTERPRET_INPUT_AS_3D) + , + uint src_cross_plane_pad +#endif // REINTERPRET_INPUT_AS_3D #if defined(REINTERPRET_OUTPUT_AS_3D) , - uint cross_plane_pad + uint dst_cross_plane_pad #endif // REINTERPRET_OUTPUT_AS_3D ) { @@ -1776,9 +1971,40 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), // Update address for the matrix B src_addr.s1 += idx * sizeof(float); +#if defined(REINTERPRET_INPUT_AS_3D) + // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension + // in order to take into account the presence of possible cross plane paddings + // + // | | + // | plane0 | + // | | + // |__________________| + // |******************| + // | cross_plane_pad | + // |******************| + // | | + // | plane1 | + // | | + // |__________________| + + // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D + uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zin = min(DEPTH_GEMM3D - 1, zin); + + // Add offset due to the cross plane paddings + zin *= (src_cross_plane_pad * src0_stride_y); + + // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we + // multiply src0_stride_z by DEPTH_GEMM3D + src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; + +#else // defined(REINTERPRET_INPUT_AS_3D) + // Add offset for batched GEMM src_addr.s0 += get_global_id(2) * src0_stride_z; +#endif // defined(REINTERPRET_INPUT_AS_3D) + #if defined(MATRIX_B_DEPTH) // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; @@ -1807,8 +2033,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), int i = 0; for(; i <= ((int)COLS_A - 8); i += 8) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0)); +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0)); +#endif // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix B float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); @@ -1848,7 +2079,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), acc01 = fma(a0.s7, b7.s1, acc01); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#if defined(REINTERPRET_INPUT_AS_3D) + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#else // defined(REINTERPRET_INPUT_AS_3D) + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#endif // defined(REINTERPRET_INPUT_AS_3D) acc10 = fma(a0.s0, b0.s0, acc10); acc10 = fma(a0.s1, b1.s0, acc10); acc10 = fma(a0.s2, b2.s0, acc10); @@ -1868,7 +2103,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), acc11 = fma(a0.s7, b7.s1, acc11); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#if defined(REINTERPRET_INPUT_AS_3D) + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#else // defined(REINTERPRET_INPUT_AS_3D) + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#endif // defined(REINTERPRET_INPUT_AS_3D) acc20 = fma(a0.s0, b0.s0, acc20); acc20 = fma(a0.s1, b1.s0, acc20); acc20 = fma(a0.s2, b2.s0, acc20); @@ -1888,7 +2127,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), acc21 = fma(a0.s7, b7.s1, acc21); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#if defined(REINTERPRET_INPUT_AS_3D) + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#else // defined(REINTERPRET_INPUT_AS_3D) + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#endif // defined(REINTERPRET_INPUT_AS_3D) acc30 = fma(a0.s0, b0.s0, acc30); acc30 = fma(a0.s1, b1.s0, acc30); acc30 = fma(a0.s2, b2.s0, acc30); @@ -1913,6 +2156,19 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), // float size increment for(; i < (int)COLS_A; ++i) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -1924,6 +2180,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix B float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); src_addr.s1 += src1_stride_y; @@ -1994,7 +2252,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings - zout *= (cross_plane_pad * dst_stride_y); + zout *= (dst_cross_plane_pad * dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D @@ -2041,7 +2299,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: + * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor @@ -2068,7 +2327,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) + * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), @@ -2076,9 +2336,13 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), uint src0_stride_z, uint src1_stride_z, uint dst_stride_z +#if defined(REINTERPRET_INPUT_AS_3D) + , + uint src_cross_plane_pad +#endif // REINTERPRET_INPUT_AS_3D #if defined(REINTERPRET_OUTPUT_AS_3D) , - uint cross_plane_pad + uint dst_cross_plane_pad #endif // REINTERPRET_OUTPUT_AS_3D ) { @@ -2093,9 +2357,40 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), // Update address for the matrix B src_addr.s1 += idx * sizeof(half); +#if defined(REINTERPRET_INPUT_AS_3D) + // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension + // in order to take into account the presence of possible cross plane paddings + // + // | | + // | plane0 | + // | | + // |__________________| + // |******************| + // | cross_plane_pad | + // |******************| + // | | + // | plane1 | + // | | + // |__________________| + + // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D + uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zin = min(DEPTH_GEMM3D - 1, zin); + + // Add offset due to the cross plane paddings + zin *= (src_cross_plane_pad * src0_stride_y); + + // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we + // multiply src0_stride_z by DEPTH_GEMM3D + src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; + +#else // defined(REINTERPRET_INPUT_AS_3D) + // Add offset for batched GEMM src_addr.s0 += get_global_id(2) * src0_stride_z; +#endif // defined(REINTERPRET_INPUT_AS_3D) + #if defined(MATRIX_B_DEPTH) // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; @@ -2117,6 +2412,19 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), int i = 0; for(; i <= ((int)COLS_A - 4); i += 4) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -2128,6 +2436,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix B half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1)); src_addr.s1 += src1_stride_y; @@ -2188,6 +2498,19 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), for(; i < (int)COLS_A; ++i) { +#if defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix A + half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -2199,6 +2522,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#endif // defined(REINTERPRET_INPUT_AS_3D) + // Load values from matrix B half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1)); @@ -2260,7 +2585,7 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings - zout *= (cross_plane_pad * dst_stride_y); + zout *= (dst_cross_plane_pad * dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D |