From a25d16c86f0d870408bc8b941aa755093417b0f0 Mon Sep 17 00:00:00 2001
From: Vidhya Sudhan Loganathan
Date: Fri, 16 Nov 2018 11:33:12 +0000
Subject: COMPMID-1266 : Add support for FP16 in CLWinogradConvolutionLayer: 5x5 kernels

Introduced F32 accumulation for the F16 Winograd GEMM and output transform.
WinogradConvolution will be available for F16 only if the fast math flag is enabled.

Change-Id: I215593c205236a0f9669218437bb40b184ec6a4f
---
 src/core/CL/cl_kernels/gemm.cl | 348 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 348 insertions(+)

(limited to 'src/core/CL/cl_kernels/gemm.cl')

diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index d24f014f11..5d5cab6578 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2297,6 +2297,354 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
 }
 
 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
+/** This OpenCL kernel computes the matrix-by-matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
+ *
+ * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point variables.
+ * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
+ *       This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
+ * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
+ * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ *
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
+ * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
+ * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ */
+__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
+                                                       IMAGE_DECLARATION(src1),
+                                                       IMAGE_DECLARATION(dst),
+                                                       uint src0_stride_z,
+                                                       uint src1_stride_z,
+                                                       uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+                                                       ,
+                                                       uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                       ,
+                                                       uint dst_cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                       )
+{
+    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
+
+    // Compute starting address for matrix A and matrix B
+    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
+
+    // Update address for the matrix A
+    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
+
+    // Update address for the matrix B
+    src_addr.s1 += idx * sizeof(half);
+
+#if defined(REINTERPRET_INPUT_AS_3D)
+    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible cross plane paddings
+    //
+    //  |                  |
+    //  |      plane0      |
+    //  |                  |
+    //  |__________________|
+    //  |******************|
+    //  |  cross_plane_pad |
+    //  |******************|
+    //  |                  |
+    //  |      plane1      |
+    //  |                  |
+    //  |__________________|
+
+    // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zin       = min(DEPTH_GEMM3D - 1, zin);
+
+    // Add offset due to the cross plane paddings
+    zin *= (src_cross_plane_pad * src0_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply src0_stride_z by DEPTH_GEMM3D
+    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
+    // Add offset for batched GEMM
+    src_addr.s0 += get_global_id(2) * src0_stride_z;
+
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
+#if defined(MATRIX_B_DEPTH)
+    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
+#else  // defined(MATRIX_B_DEPTH)
+    src_addr.s1 += get_global_id(2) * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+    float8 acc0 = 0.0f;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    float8 acc1 = 0.0f;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    float8 acc2 = 0.0f;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    float8 acc3 = 0.0f;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+    int i = 0;
+    for(; i <= ((int)COLS_A - 4); i += 4)
+    {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
+        // Load values from matrix B
+        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+        src_addr.s1 += src1_stride_y;
+
+        // Accumulate
+        acc0 = fma(b0, (float8)a0.s0, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        acc1 = fma(b0, (float8)a1.s0, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        acc2 = fma(b0, (float8)a2.s0, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        acc3 = fma(b0, (float8)a3.s0, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+        src_addr.s1 += src1_stride_y;
+        acc0 = fma(b0, (float8)a0.s1, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        acc1 = fma(b0, (float8)a1.s1, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        acc2 = fma(b0, (float8)a2.s1, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        acc3 = fma(b0, (float8)a3.s1, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+        src_addr.s1 += src1_stride_y;
+        acc0 = fma(b0, (float8)a0.s2, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        acc1 = fma(b0, (float8)a1.s2, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        acc2 = fma(b0, (float8)a2.s2, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        acc3 = fma(b0, (float8)a3.s2, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+        src_addr.s1 += src1_stride_y;
+        acc0 = fma(b0, (float8)a0.s3, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        acc1 = fma(b0, (float8)a1.s3, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        acc2 = fma(b0, (float8)a2.s3, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        acc3 = fma(b0, (float8)a3.s3, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += 4 * sizeof(half);
+    }
+
+    for(; i < (int)COLS_A; ++i)
+    {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
+        // Load values from matrix B
+        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+
+        src_addr += (int2)(sizeof(half), src1_stride_y);
+
+        // Accumulate
+        acc0 = fma(b0, (float8)a0, acc0); // b0 * (float8)a0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        acc1 = fma(b0, (float8)a1, acc1); // b0 * (float8)a1;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        acc2 = fma(b0, (float8)a2, acc2); // b0 * (float8)a2;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        acc3 = fma(b0, (float8)a3, acc3); // b0 * (float8)a3;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    }
+
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
+#else  // defined(ALPHA)
+    half8 hacc0 = convert_half8(acc0);
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if defined(ALPHA)
+    half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
+#else  // defined(ALPHA)
+    half8 hacc1 = convert_half8(acc1);
+#endif // defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if defined(ALPHA)
+    half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
+#else  // defined(ALPHA)
+    half8 hacc2 = convert_half8(acc2);
+#endif // defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if defined(ALPHA)
+    half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
+#else  // defined(ALPHA)
+    half8 hacc3 = convert_half8(acc3);
+#endif // defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+    int z = get_global_id(2);
+
+    // Compute destination address
+    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+    // Compute dst address
+    __global uchar *dst_addr = offset(&dst, 0, 0);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible cross plane paddings
+    //
+    //  |                  |
+    //  |      plane0      |
+    //  |                  |
+    //  |__________________|
+    //  |******************|
+    //  |  cross_plane_pad |
+    //  |******************|
+    //  |                  |
+    //  |      plane1      |
+    //  |                  |
+    //  |__________________|
+
+    // The plane (zout) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the cross plane paddings
+    zout *= (dst_cross_plane_pad * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store the output block
+    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+
+    // Store the output block
+    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // REINTERPRET_OUTPUT_AS_3D
+}
+
 /** This OpenCL kernel computes the matrix-by-matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
  *
  * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
-- 
cgit v1.2.1
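
The essence of the new kernel is that the acc0..acc3 accumulators stay in float for the whole K loop and are converted to half only right before the vstore8. The stand-alone OpenCL C sketch below is purely illustrative (the kernel name, buffer layout and output slots are made up for this note, not taken from the patch); it contrasts a plain fp16 accumulator with the fp32-accumulate-then-convert pattern used by gemm_mm_floating_point_f16_bifrost_acc32.

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Illustrative only: why accumulating half products in float matters.
__kernel void dot_acc_compare(__global const half *a, __global const half *b,
                              __global float *out, int n)
{
    half  acc_h = 0.0h; // fp16 accumulator: ~3 decimal digits, rounds after every add
    float acc_f = 0.0f; // fp32 accumulator: what the new kernel keeps internally
    for(int i = 0; i < n; ++i)
    {
        acc_h += a[i] * b[i];                          // result rounded to fp16 each iteration
        acc_f  = fma((float)a[i], (float)b[i], acc_f); // full fp32 precision until the end
    }
    out[0] = (float)acc_h; // drifts as n (the GEMM K dimension) grows
    out[1] = acc_f;        // converted to half only once in the real kernel
}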
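
The REINTERPRET_INPUT_AS_3D / REINTERPRET_OUTPUT_AS_3D paths compute zin/zout so that each row skips the bottom padding of every plane it has crossed. A scalar restatement of that address arithmetic, with HEIGHT_GEMM3D and DEPTH_GEMM3D set to example values chosen only for this note (the patch does not fix these values), might look like:

// Scalar sketch of the zin/zout logic above, for a single output row.
uint cross_plane_byte_offset(uint row, uint cross_plane_pad, uint stride_y)
{
    const uint height_gemm3d = 4; // rows per 2D plane, example value for -DHEIGHT_GEMM3D
    const uint depth_gemm3d  = 3; // number of planes,  example value for -DDEPTH_GEMM3D

    // Plane this row lands in, clamped to the last plane (as min(DEPTH_GEMM3D - 1, zin) does)
    uint plane = min(row / height_gemm3d, depth_gemm3d - 1u);

    // Every plane boundary crossed so far adds cross_plane_pad padded rows of stride_y bytes
    return plane * cross_plane_pad * stride_y;
}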