aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-07-26 11:44:03 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit68a3f56627b04acdefebe67d645727dd83889766 (patch)
tree4a3f4dc0facfda861a5ba7afa29d84d82d0829c2 /src
parent4e0d3819be6c61cc00c7e0fa9b4b740738c703b7 (diff)
downloadComputeLibrary-68a3f56627b04acdefebe67d645727dd83889766.tar.gz
COMPMID-1276 - Allow GEMM to work with 3D input tensor
Skipped im2col in CLGEMMConvolutionLayer for 1x1 convolutions with NHWC data layout Change-Id: I894e6b952ed8605e8f3ffc0ffc25c24730d4664c Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141909 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl375
-rw-r--r--src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp73
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp47
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp32
-rw-r--r--src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp76
-rw-r--r--src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp7
6 files changed, 507 insertions, 103 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 5a6efe64b9..932e0d681a 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -88,6 +88,11 @@ __kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
*
* @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
*
* @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -105,9 +110,15 @@ __kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
*/
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
+ TENSOR3D_DECLARATION(dst)
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+ )
{
// Compute source and destination addresses
uint x = get_global_id(0);
@@ -124,6 +135,45 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
// Add offset for batched GEMM
dst_addr_in_bytes += z * dst_stride_z;
+#if defined(REINTERPRET_INPUT_AS_3D)
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;
+
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (cross_plane_pad * src_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src_stride_z by DEPTH_GEMM3D
+ input_ptr += z * src_stride_z * DEPTH_GEMM3D;
+
+ // Load values from Matrix A
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
+#else // defined(REINTERPRET_INPUT_AS_3D)
__global uchar *input_ptr = src.ptr;
// Load values from Matrix A
@@ -135,6 +185,7 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
VEC_DATA_TYPE(DATA_TYPE, 4)
val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
@@ -188,7 +239,7 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -366,7 +417,7 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -679,7 +730,7 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -853,7 +904,7 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1095,7 +1146,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1122,7 +1174,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1130,9 +1183,13 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -1147,9 +1204,40 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(DATA_TYPE);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1172,6 +1260,23 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
VEC_DATA_TYPE(DATA_TYPE, 2)
a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1187,6 +1292,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
VEC_DATA_TYPE(DATA_TYPE, 2)
a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
@@ -1210,6 +1317,19 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1221,6 +1341,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
@@ -1280,7 +1402,7 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1335,7 +1457,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1362,7 +1485,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1370,9 +1494,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -1387,9 +1515,40 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
// Update address for matrix B
src_addr.s1 += idx * sizeof(float);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1428,6 +1587,19 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
int i = 0;
for(; i <= ((int)COLS_A - 4); i += 4)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A and matrix B
+ float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A and matrix B
float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1439,6 +1611,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1579,8 +1753,21 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
- float a0 = *((__global float *)(src0_ptr + src_addr.s0));
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1590,6 +1777,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1676,7 +1865,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1723,7 +1912,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1750,7 +1940,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1758,9 +1949,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -1776,9 +1971,40 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(float);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1807,8 +2033,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
int i = 0;
for(; i <= ((int)COLS_A - 8); i += 8)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
@@ -1848,7 +2079,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc01 = fma(a0.s7, b7.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc10 = fma(a0.s0, b0.s0, acc10);
acc10 = fma(a0.s1, b1.s0, acc10);
acc10 = fma(a0.s2, b2.s0, acc10);
@@ -1868,7 +2103,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc20 = fma(a0.s0, b0.s0, acc20);
acc20 = fma(a0.s1, b1.s0, acc20);
acc20 = fma(a0.s2, b2.s0, acc20);
@@ -1888,7 +2127,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc30 = fma(a0.s0, b0.s0, acc30);
acc30 = fma(a0.s1, b1.s0, acc30);
acc30 = fma(a0.s2, b2.s0, acc30);
@@ -1913,6 +2156,19 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
// float size increment
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1924,6 +2180,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1994,7 +2252,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -2041,7 +2299,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -2068,7 +2327,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -2076,9 +2336,13 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -2093,9 +2357,40 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(half);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -2117,6 +2412,19 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
int i = 0;
for(; i <= ((int)COLS_A - 4); i += 4)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -2128,6 +2436,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -2188,6 +2498,19 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -2199,6 +2522,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
@@ -2260,7 +2585,7 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
diff --git a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
index 12a40cd7dc..6ea1160c69 100644
--- a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
+++ b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
+#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/CLValidate.h"
@@ -30,6 +31,7 @@
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Window.h"
@@ -40,7 +42,7 @@ using namespace arm_compute::misc::shape_calculator;
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
ARM_COMPUTE_RETURN_ERROR_ON(mult_interleave4x4_height < 1);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
@@ -50,24 +52,30 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, i
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height, reinterpret_input_as_3d));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
constexpr unsigned int num_elems_processed_per_iteration_x = 4;
constexpr unsigned int num_elems_processed_per_iteration_y = 4;
const unsigned int num_elems_written_per_iteration = num_elems_processed_per_iteration_x * num_elems_processed_per_iteration_y * mult_interleave4x4_height;
bool window_changed = false;
- // Configure kernel window
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
- window_changed = window_changed || update_window_and_padding(win, input_access);
+ TensorInfo tmp_info(*input);
+
+ if(reinterpret_input_as_3d)
+ {
+ // Since the input tensor has to be reinterpreted as 3D and the execute window is based on a 2D interleave,
+ // the window needs to be constructed on the 2D collapsed version of the tensor
+ TensorShape tmp_shape(input->tensor_shape());
+ tmp_shape.collapse(2U, 1U);
+ tmp_info.set_tensor_shape(tmp_shape);
+ }
// Output auto inizialitation if not yet initialized
auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_interleaved_shape(*input, mult_interleave4x4_height)));
@@ -76,9 +84,22 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
const float scale_x = 4.0f * static_cast<float>(mult_interleave4x4_height);
const float scale_y = 1.0f / (scale_x);
+ // Note: bottom paddings are calculated manually as the input can be reinterpreted as 3D tensor
+ // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
+ const int m = reinterpret_input_as_3d ? input->tensor_shape()[1] * input->tensor_shape()[2] : input->tensor_shape()[1];
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
+ Window win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ Window win_in = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+ AccessWindowStatic input_access(input, 0, 0,
+ ceil_to_multiple(input->dimension(0), num_elems_processed_per_iteration_x),
+ input->dimension(1) + bottom_pad);
AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration, 1, scale_x, scale_y);
- window_changed = window_changed || update_window_and_padding(win, output_access);
- output_access.set_valid_region(win, input->valid_region());
+
+ window_changed = update_window_and_padding(win_in, input_access) || // window used by the execute_window_loop
+ update_window_and_padding(win, output_access); // window used to update the padding requirements of output tensor
+ output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
@@ -90,26 +111,31 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
} // namespace
CLGEMMInterleave4x4Kernel::CLGEMMInterleave4x4Kernel()
- : _input(nullptr), _output(nullptr)
+ : _input(nullptr), _output(nullptr), _reinterpret_input_as_3d(false)
{
}
-void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height)
+void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height)));
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height, reinterpret_input_as_3d)));
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d));
- _input = input;
- _output = output;
+ _input = input;
+ _output = output;
+ _reinterpret_input_as_3d = reinterpret_input_as_3d;
// Create build options
CLBuildOptions build_opts;
build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(1)));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(2)));
+
switch(input->info()->element_size())
{
case 1:
@@ -129,12 +155,13 @@ void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *out
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_interleave4x4", build_opts.options()));
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height);
+ auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = "interleave4x4_";
+ _config_id += (_reinterpret_input_as_3d ? "3d_" : "");
_config_id += lower_string(string_from_data_type(input->info()->data_type()));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(0));
@@ -146,10 +173,10 @@ void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *out
_config_id += support::cpp11::to_string(output->info()->dimension(3));
}
-Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height, reinterpret_input_as_3d));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height, reinterpret_input_as_3d).first);
return Status{};
}
@@ -170,6 +197,14 @@ void CLGEMMInterleave4x4Kernel::run(const Window &window, cl::CommandQueue &queu
*/
Window slice = window.first_slice_window_3D();
+ if(_reinterpret_input_as_3d)
+ {
+ // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
+ const unsigned int idx0 = 2 * num_arguments_per_3D_tensor();
+ const unsigned int total_cross_plane_pad = _input->info()->padding().top + _input->info()->padding().bottom;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+ }
+
do
{
unsigned int idx = 0;
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 0c629af788..c9e6bb34b2 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -56,6 +56,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
if(!is_interleaved_transposed)
{
@@ -125,6 +126,9 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
if(is_interleaved_transposed)
{
+ // reinterpret_input_as_3d is not supported if is_interleaved_transposed is set
+ ARM_COMPUTE_ERROR_ON(reshape_info.reinterpret_input_as_3d());
+
// Configure kernel window
num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
num_elems_processed_per_iteration_y = 4;
@@ -158,7 +162,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
// Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
// The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
- const int m = input0->tensor_shape()[1];
+ const int m = reshape_info.reinterpret_input_as_3d() ? input0->tensor_shape()[1] * input0->tensor_shape()[2] : input0->tensor_shape()[1];
const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
// Create kernels according to the architecture, data type and input size.
@@ -172,7 +176,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
+ AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), input0->dimension(1) + bottom_pad);
AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
AccessWindowStatic output_access(output, 0, 0,
ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
@@ -198,7 +202,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
} // namespace
CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
- : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _is_gemm3d(false)
+ : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false)
{
}
@@ -209,19 +213,22 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Perform validate step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
- _input0 = input0;
- _input1 = input1;
- _output = output;
- _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
+ _input0 = input0;
+ _input1 = input1;
+ _output = output;
+ _reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
+ _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 1);
+
+ // Check if we need to slide the matrix B
+ const unsigned int num_dimensions_input0 = _reinterpret_input_as_3d ? _input0->info()->num_dimensions() - 1 : _input0->info()->num_dimensions();
+
+ _slide_matrix_b = (_input1->info()->num_dimensions() >= num_dimensions_input0);
const DataType data_type = input0->info()->data_type();
// Get target architecture
GPUTarget gpu_target = get_target();
- // Check if the output has to be reinterpreted as 3D
- _is_gemm3d = (reshape_info.depth_output_gemm3d() != 1) && is_data_type_float(data_type);
-
ElementsProcessed num_elements_processed{};
// Configure kernel window
@@ -237,9 +244,10 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
{
build_opts.add_option("-DALPHA=" + float_to_string_with_full_precision(alpha));
}
- build_opts.add_option_if(_is_gemm3d, "-DREINTERPRET_OUTPUT_AS_3D");
- build_opts.add_option_if(_is_gemm3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
- build_opts.add_option_if(_is_gemm3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
// Do not slide matrix B if _slide_matrix_b = false
build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
@@ -305,7 +313,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Set config_id for enabling LWS tuning
_config_id = "gemm_";
_config_id += (is_interleaved_transposed ? "reshaped_" : "");
- _config_id += (_is_gemm3d ? "3d_" : "");
+ _config_id += (_reinterpret_input_as_3d ? "3di_" : "");
+ _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
@@ -355,10 +364,18 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
- if(_is_gemm3d)
+ if(_reinterpret_input_as_3d)
{
// Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+ const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+ }
+
+ if(_reinterpret_output_as_3d)
+ {
+ // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
+ const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
_kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
}
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1f4df4f1a9..1d1b17bbf1 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -102,7 +102,8 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
// in order to know how the matrices have been reshaped
- const int m = a->info()->dimension(1);
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
const int n = b->info()->dimension(0);
const int k = a->info()->dimension(0);
const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
@@ -118,6 +119,12 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
// Check if we need to reshape the matrix A and matrix B
_is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
+ // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
+ if(_is_interleaved_transposed)
+ {
+ reinterpret_input_as_3d = false;
+ }
+
if(_is_interleaved_transposed)
{
matrix_a = &_tmp_a;
@@ -132,14 +139,15 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
// _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
// Configure interleave kernel
- _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height);
+ _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
// Configure transpose kernel
_transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
}
// Configure and tune matrix multiply kernel
- _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d));
+ _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
+ reinterpret_input_as_3d));
CLScheduler::get().tune_kernel_static(_mm_kernel);
if(_is_interleaved_transposed)
@@ -180,11 +188,13 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
// in order to know how the matrices have been reshaped
- const int m = a->dimension(1);
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
const int n = b->dimension(0);
const int k = a->dimension(0);
int mult_transpose1xW_width = 1;
int mult_interleave4x4_height = 1;
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
{
@@ -192,19 +202,25 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
mult_interleave4x4_height = 2;
}
- const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
-
// Check if we need to reshape the matrix A and matrix B
const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
+ // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
+ if(run_interleave_transpose)
+ {
+ reinterpret_input_as_3d = false;
+ }
+
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+
if(run_interleave_transpose)
{
matrix_a_info = &tmp_a_info;
matrix_b_info = &tmp_b_info;
// Validate interleave kernel
- auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height)));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height));
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
// Validate transpose kernel
auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index f1d2924c92..de628291eb 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -91,15 +91,15 @@ void CLConvolutionLayerReshapeWeights::run()
CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(),
- _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _skip_im2col(false), _is_quantized(false),
- _is_activationlayer_enabled(false), _is_prepared(false)
+ _add_bias_kernel(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
+ _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
{
}
void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
- ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info()));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), _skip_im2col));
if(_is_quantized)
{
@@ -120,15 +120,16 @@ void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTenso
else
{
// Configure matrix multiply function
- _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth));
+ _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
+ _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
}
}
-Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth)
+Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth, bool skip_im2col)
{
const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
- const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth);
+ const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
if(is_quantized)
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -180,7 +181,8 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
_original_weights = weights;
_is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
_data_layout = data_layout;
- _skip_im2col = false;
+ _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !_is_quantized;
+ _append_bias = (biases != nullptr) && (!_is_quantized);
// Set the GPU target for im2col and col2im
_im2col_kernel.set_target(CLScheduler::get().target());
@@ -191,9 +193,8 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
ICLTensor *gemm_output_to_use = output;
ICLTensor *gemm_output_staged_to_use = output;
- const bool append_bias = (biases != nullptr) && (!_is_quantized);
- const unsigned bias_element = (append_bias) ? 1 : 0;
- const ICLTensor *biases_to_use = (append_bias) ? biases : nullptr;
+ const unsigned bias_element = (_append_bias && !_skip_im2col) ? 1 : 0;
+ const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
// Get parameters from conv_info
unsigned int stride_x = 0;
@@ -238,12 +239,17 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
_memory_group.manage(&_im2col_output);
// Configure and tune im2col
- _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation);
+ _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
CLScheduler::get().tune_kernel_static(_im2col_kernel);
// Update GEMM input
gemm_input_to_use = &_im2col_output;
}
+ else if(_append_bias)
+ {
+ // Configure add bias kernel
+ _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
+ }
// Create GEMM output tensor
if(!is_nhwc || _is_quantized)
@@ -281,28 +287,23 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
int output_multiplier, output_shift;
quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- if(!is_nhwc)
- {
- _memory_group.manage(&_tmp_output);
- gemm_output_staged_to_use = &_tmp_output;
- }
+
+ _memory_group.manage(&_tmp_output);
+ gemm_output_staged_to_use = &_tmp_output;
+
_gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset);
}
- if(!is_nhwc)
+ if(!is_nhwc || _is_quantized)
{
// Configure and tune Col2Im
_col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, std::make_pair(conv_w, conv_h));
CLScheduler::get().tune_kernel_static(_col2im_kernel);
}
- if(_is_quantized && !is_nhwc)
- {
- _tmp_output.allocator()->allocate();
- }
-
if(!is_nhwc || _is_quantized)
{
+ _tmp_output.allocator()->allocate();
_gemm_output.allocator()->allocate();
}
@@ -348,10 +349,10 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
const ITensorInfo *weights_to_use = weights;
const bool is_nhwc = data_layout == DataLayout::NHWC;
- const bool skip_im2col = false;
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+ const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !is_quantized;
const bool append_bias = (biases != nullptr) && (!is_quantized);
- const unsigned bias_element = (append_bias) ? 1 : 0;
+ const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -410,6 +411,11 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
gemm_input_to_use = &im2col_reshaped_info;
}
+ else if(append_bias)
+ {
+ // Validate add bias kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
+ }
// Create GEMM output tensor
if(!is_nhwc || is_quantized)
@@ -424,25 +430,24 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
gemm_output_to_use = &info_gemm;
}
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1, skip_im2col));
if(is_quantized)
{
float multiplier = input->quantization_info().scale * weights_to_use->quantization_info().scale / output->quantization_info().scale;
int output_multiplier, output_shift;
quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- if(!is_nhwc)
- {
- tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
- tmp_info.set_quantization_info(output->quantization_info());
- gemm_output_staged_to_use = &tmp_info;
- }
+
+ tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
+ tmp_info.set_quantization_info(output->quantization_info());
+ gemm_output_staged_to_use = &tmp_info;
+
// Validate output stage for quantized case
CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset);
}
// Validate Col2Im
- if(!is_nhwc)
+ if(!is_nhwc || is_quantized)
{
ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use,
output,
@@ -485,8 +490,13 @@ void CLGEMMConvolutionLayer::run()
_mm_gemm.run();
}
+ if(_skip_im2col && _append_bias)
+ {
+ CLScheduler::get().enqueue(_add_bias_kernel);
+ }
+
// Reshape output matrix
- if(_data_layout == DataLayout::NCHW)
+ if(_data_layout == DataLayout::NCHW || _is_quantized)
{
CLScheduler::get().enqueue(_col2im_kernel, false);
}
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 842ee73397..c2e18a760a 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -205,16 +205,17 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
const int k = a->dimension(0);
constexpr int mult_transpose1xW_width = 1;
constexpr int mult_interleave4x4_height = 1;
- const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
+ const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d);
bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
if(reshape_matrices)
{
- TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height), 1, a->data_type());
+ TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()), 1, a->data_type());
TensorInfo info_b(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width), 1, b->data_type());
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &info_b, mult_transpose1xW_width));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output, reshape_matrices, reshape_info));
}