aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm.cl
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2018-03-01 16:42:00 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:53:09 +0000
commit8e74f4488daf1b628ca718396d5fc72fea95a83d (patch)
treef372c61aab423799f82ea7a98aa6a157a4887bdc /src/core/CL/cl_kernels/gemm.cl
parent0a887922c73bbe7c5d42b1eb3ae55730f0d9a139 (diff)
downloadComputeLibrary-8e74f4488daf1b628ca718396d5fc72fea95a83d.tar.gz
COMPMID-911: Allow GEMM to work with 3D tensors
Change-Id: I8c4823a0d909e19e9ef548f00b9ae98c66de61dd Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123569 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl668
1 files changed, 565 insertions, 103 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index ad38c7ebd0..23681252d1 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -165,6 +165,12 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -183,13 +189,22 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -273,6 +288,40 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x4 block
+ vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -281,6 +330,7 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -293,6 +343,12 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -311,13 +367,22 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -533,6 +598,40 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x4 block
+ vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -541,6 +640,7 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
// Undefine local defines
@@ -556,6 +656,12 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -574,13 +680,22 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -664,6 +779,40 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x8 block
+ vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -672,6 +821,7 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -683,6 +833,12 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -701,13 +857,19 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -876,11 +1038,49 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x8 block
+ vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
// Store 4x8 block
vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
// Undefine local defines
@@ -917,6 +1117,9 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1039,6 +1242,9 @@ __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1138,6 +1344,12 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1156,13 +1368,22 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1271,36 +1492,85 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc1 = acc1 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc2 = acc2 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc3 = acc3 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store output block
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store output block
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc1 = acc1 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc2 = acc2 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc3 = acc3 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(DATA_TYPE)
@@ -1314,6 +1584,12 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1332,13 +1608,22 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1585,6 +1870,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
src_addr.s0 += sizeof(float);
}
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
@@ -1595,46 +1882,83 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc02 = acc02 * ALPHA;
acc03 = acc03 * ALPHA;
#endif // defined(ALPHA)
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
- float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
- vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
acc10 = acc10 * ALPHA;
acc11 = acc11 * ALPHA;
acc12 = acc12 * ALPHA;
acc13 = acc13 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
- vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
acc20 = acc20 * ALPHA;
acc21 = acc21 * ALPHA;
acc22 = acc22 * ALPHA;
acc23 = acc23 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
- vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
acc30 = acc30 * ALPHA;
acc31 = acc31 * ALPHA;
acc32 = acc32 * ALPHA;
acc33 = acc33 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
- vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store the output block
+ vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
@@ -1648,6 +1972,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1666,13 +1996,22 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
// Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1857,46 +2196,87 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
src_addr.s0 += sizeof(float);
}
+ // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+ acc00 = acc00 * ALPHA;
+ acc01 = acc01 * ALPHA;
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc10 = acc10 * ALPHA;
+ acc11 = acc11 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc20 = acc20 * ALPHA;
+ acc21 = acc21 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc30 = acc30 * ALPHA;
+ acc31 = acc31 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
+ dst_addr += z * dst_stride_z;
- // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
- acc00 = acc00 * ALPHA;
- acc01 = acc01 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc0 = ((float2)(acc00, acc01));
- vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ // Store the output block
+ vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc10 = acc10 * ALPHA;
- acc11 = acc11 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc1 = ((float2)(acc10, acc11));
- vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc20 = acc20 * ALPHA;
- acc21 = acc21 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc2 = ((float2)(acc20, acc21));
- vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc30 = acc30 * ALPHA;
- acc31 = acc31 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc3 = (float2)(acc30, acc31);
- vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+ vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
@@ -1910,6 +2290,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1928,13 +2314,22 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -2071,38 +2466,83 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
+ // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+ acc0 = acc0 * (half8)ALPHA;
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc1 = acc1 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc2 = acc2 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc3 = acc3 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
+ dst_addr += z * dst_stride_z;
- // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
- acc0 = acc0 * (half8)ALPHA;
-#endif // defined(ALPHA)
+ // Store the output block
vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc1 = acc1 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc2 = acc2 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc3 = acc3 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // REINTERPRET_OUTPUT_AS_3D
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
@@ -2135,6 +2575,9 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -2319,6 +2762,9 @@ __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -2471,20 +2917,24 @@ __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
@@ -2509,20 +2959,24 @@ __kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
@@ -2550,20 +3004,24 @@ __kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_qs8(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
char16 alpha_ab = vload16(0, (__global char *)dst.ptr);
@@ -2589,20 +3047,24 @@ __kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_qs16(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
short8 alpha_ab = vload8(0, (__global short *)dst.ptr);