aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2018-03-01 16:42:00 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:53:09 +0000
commit8e74f4488daf1b628ca718396d5fc72fea95a83d (patch)
treef372c61aab423799f82ea7a98aa6a157a4887bdc
parent0a887922c73bbe7c5d42b1eb3ae55730f0d9a139 (diff)
downloadComputeLibrary-8e74f4488daf1b628ca718396d5fc72fea95a83d.tar.gz
COMPMID-911: Allow GEMM to work with 3D tensors
Change-Id: I8c4823a0d909e19e9ef548f00b9ae98c66de61dd Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123569 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r--arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h1
-rw-r--r--arm_compute/core/Types.h46
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h23
-rw-r--r--arm_compute/runtime/CL/functions/CLGEMM.h7
-rw-r--r--src/core/CL/cl_kernels/gemm.cl668
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp8
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp98
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp9
-rw-r--r--tests/datasets/LargeGEMMDataset.h16
-rw-r--r--tests/datasets/MatrixMultiplyGEMMDataset.h2
-rw-r--r--tests/datasets/SmallGEMMDataset.h13
-rw-r--r--tests/validation/CL/GEMM.cpp40
-rw-r--r--tests/validation/fixtures/GEMMFixture.h9
-rw-r--r--utils/TypePrinter.h1
14 files changed, 767 insertions, 174 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
index 13802b97ad..15bba0cd0f 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
@@ -85,6 +85,7 @@ public:
const ICLTensor *_input1;
ICLTensor *_output;
bool _slide_matrix_b;
+ bool _is_gemm3d;
};
} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLGEMMMATRIXMULTIPLYKERNEL_H__ */
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 12c4e25222..da28e131de 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1013,7 +1013,7 @@ class GEMMReshapeInfo final
public:
/** Default constructor */
GEMMReshapeInfo()
- : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1)
+ : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(1)
{
}
/** Constructor
@@ -1023,9 +1023,10 @@ public:
* @param[in] k Number of matrix A columns or matrix B rows
* @param[in] mult_transpose1xW_width (Optional) Multiplication factor for the width of the 1xW transposed block
* @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block
+ * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
*/
- GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1)
- : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height)
+ GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 1)
+ : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d)
{
}
/** Number of matrix A rows
@@ -1068,6 +1069,17 @@ public:
{
return _mult_interleave4x4_height;
}
+ /** Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
+ *
+ * @note GEMM3D kernel is used when the output has to be reinterpret as 3D tensor. In that case:
+ * m = depth_output_gemm3d * output_height
+ *
+ * @return the depth of the output tensor to be used with the GEMM3D kernel
+ */
+ int depth_output_gemm3d() const
+ {
+ return _depth_output_gemm3d;
+ }
private:
const int _m;
@@ -1075,6 +1087,7 @@ private:
const int _k;
const int _mult_transpose1xW_width;
const int _mult_interleave4x4_height;
+ const int _depth_output_gemm3d;
};
/** GEMM information class. This class stores the necessary information to compute GEMM functions
@@ -1087,7 +1100,7 @@ class GEMMInfo
public:
/** Default constructor */
GEMMInfo()
- : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _reshape_info()
+ : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1)
{
}
/** Constructor
@@ -1095,10 +1108,11 @@ public:
* @param[in] is_a_reshaped True if the matrix A has been reshaped
* @param[in] is_b_reshaped True if the matrix B has been reshaped
* @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
- * @param[in] reshape_info (Optional) GEMM reshape information object
+ * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
+ *
*/
- GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo())
- : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _reshape_info(reshape_info)
+ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1)
+ : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -1127,20 +1141,20 @@ public:
{
return _reshape_b_only_on_first_run;
};
- /** GEMMReshapeInfo object which stores the necessary information to understand how the matrix A and matrix B have been reshaped
+ /** Depth of the output when GEMM output is reinterpreted as 3D tensor
*
- * @return the GEMMReshapeInfo object
+ * @return the depth of the output tensor
*/
- const GEMMReshapeInfo &reshape_info() const
+ int depth_output_gemm3d() const
{
- return _reshape_info;
- }
+ return _depth_output_gemm3d;
+ };
private:
- const bool _is_a_reshaped;
- const bool _is_b_reshaped;
- const bool _reshape_b_only_on_first_run;
- GEMMReshapeInfo _reshape_info;
+ const bool _is_a_reshaped;
+ const bool _is_b_reshaped;
+ const bool _reshape_b_only_on_first_run;
+ const int _depth_output_gemm3d;
};
/** Winograd information */
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index deab181aee..9666702749 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -360,11 +360,26 @@ inline TensorShape compute_rnn_shape(const ITensorInfo *input, const unsigned in
}
inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
{
- TensorShape tensor_shape{ input0.tensor_shape() };
- tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1.dimension(0));
- tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0.dimension(1));
+ ARM_COMPUTE_ERROR_ON_MSG(input0.num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
- return tensor_shape;
+ const bool is_gemm3d = reshape_info.depth_output_gemm3d() != 1;
+
+ // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
+ // dimension of the output tensor
+ const int dim0 = is_interleaved_transposed ? reshape_info.n() : input1.dimension(0);
+ const int dim1 = is_interleaved_transposed ? reshape_info.m() / reshape_info.depth_output_gemm3d() : input0.dimension(1) / reshape_info.depth_output_gemm3d();
+ const int dim2 = input0.tensor_shape()[2];
+ const int dim3 = input0.tensor_shape()[3];
+
+ TensorShape output_shape{ input0.tensor_shape() };
+
+ output_shape.set(0, dim0);
+ output_shape.set(1, dim1);
+ output_shape.set(2, is_gemm3d ? reshape_info.depth_output_gemm3d() : dim2);
+ output_shape.set(3, is_gemm3d ? dim2 : dim3);
+ output_shape.set(4, is_gemm3d ? dim3 : 1);
+
+ return output_shape;
}
template <typename T>
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 7a145cc183..41c7467a3f 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -24,7 +24,6 @@
#ifndef __ARM_COMPUTE_CLGEMM_H__
#define __ARM_COMPUTE_CLGEMM_H__
-#include "arm_compute/core/CL/kernels/CLFillBorderKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
@@ -34,16 +33,14 @@
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
-#include <memory>
-
namespace arm_compute
{
class ICLTensor;
/** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
*
- * -# @ref CLGEMMInterleave4x4Kernel (if the output tensor is a matrix)
- * -# @ref CLGEMMTranspose1xWKernel (if the output tensor is a matrix)
+ * -# @ref CLGEMMInterleave4x4Kernel (only if the reshaped GEMM is selected by the heuristic model)
+ * -# @ref CLGEMMTranspose1xWKernel (only if the reshaped GEMM is selected by the heuristic model)
* -# @ref CLGEMMMatrixMultiplyKernel
* -# @ref CLGEMMMatrixAdditionKernel (if c != nullptr and beta != 0.0)
*
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index ad38c7ebd0..23681252d1 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -165,6 +165,12 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -183,13 +189,22 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -273,6 +288,40 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x4 block
+ vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -281,6 +330,7 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -293,6 +343,12 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -311,13 +367,22 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -533,6 +598,40 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x4 block
+ vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -541,6 +640,7 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
// Undefine local defines
@@ -556,6 +656,12 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -574,13 +680,22 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -664,6 +779,40 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x8 block
+ vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -672,6 +821,7 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -683,6 +833,12 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -701,13 +857,19 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -876,11 +1038,49 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store 4x8 block
+ vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+ vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+ vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+ vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
// Store 4x8 block
vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
// Undefine local defines
@@ -917,6 +1117,9 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1039,6 +1242,9 @@ __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1138,6 +1344,12 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1156,13 +1368,22 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1271,36 +1492,85 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc1 = acc1 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc2 = acc2 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc3 = acc3 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store output block
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store output block
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc1 = acc1 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc2 = acc2 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc3 = acc3 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(DATA_TYPE)
@@ -1314,6 +1584,12 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1332,13 +1608,22 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1585,6 +1870,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
src_addr.s0 += sizeof(float);
}
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
@@ -1595,46 +1882,83 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
acc02 = acc02 * ALPHA;
acc03 = acc03 * ALPHA;
#endif // defined(ALPHA)
-
- // Compute dst address
- __global uchar *dst_addr = offset(&dst, 0, 0);
-
- // Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
-
- float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
- vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
acc10 = acc10 * ALPHA;
acc11 = acc11 * ALPHA;
acc12 = acc12 * ALPHA;
acc13 = acc13 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
- vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
acc20 = acc20 * ALPHA;
acc21 = acc21 * ALPHA;
acc22 = acc22 * ALPHA;
acc23 = acc23 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
- vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
acc30 = acc30 * ALPHA;
acc31 = acc31 * ALPHA;
acc32 = acc32 * ALPHA;
acc33 = acc33 * ALPHA;
-#endif // defined(ALPHA)
- float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
- vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store the output block
+ vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
@@ -1648,6 +1972,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1666,13 +1996,22 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
// Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1857,46 +2196,87 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
src_addr.s0 += sizeof(float);
}
+ // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+ acc00 = acc00 * ALPHA;
+ acc01 = acc01 * ALPHA;
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc10 = acc10 * ALPHA;
+ acc11 = acc11 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc20 = acc20 * ALPHA;
+ acc21 = acc21 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc30 = acc30 * ALPHA;
+ acc31 = acc31 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
+ dst_addr += z * dst_stride_z;
- // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
- acc00 = acc00 * ALPHA;
- acc01 = acc01 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc0 = ((float2)(acc00, acc01));
- vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ // Store the output block
+ vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc10 = acc10 * ALPHA;
- acc11 = acc11 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc1 = ((float2)(acc10, acc11));
- vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc20 = acc20 * ALPHA;
- acc21 = acc21 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc2 = ((float2)(acc20, acc21));
- vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc30 = acc30 * ALPHA;
- acc31 = acc31 * ALPHA;
-#endif // defined(ALPHA)
- float2 acc3 = (float2)(acc30, acc31);
- vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+ vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
@@ -1910,6 +2290,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
* @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1928,13 +2314,22 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] pad_bottom Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
IMAGE_DECLARATION(dst),
uint src0_stride_z,
uint src1_stride_z,
- uint dst_stride_z)
+ uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -2071,38 +2466,83 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
+ // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+ acc0 = acc0 * (half8)ALPHA;
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+ acc1 = acc1 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+ acc2 = acc2 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+ acc3 = acc3 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+ int z = get_global_id(2);
+
// Compute destination address
Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible bottom paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |_____________|
+ // |*************|
+ // | pad_bottom |
+ // |*************|
+ // | |
+ // | plane1 |
+ // | |
+ // |_____________|
+
+ // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the bottom paddings
+ zout *= (pad_bottom * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
- dst_addr += get_global_id(2) * dst_stride_z;
+ dst_addr += z * dst_stride_z;
- // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
- acc0 = acc0 * (half8)ALPHA;
-#endif // defined(ALPHA)
+ // Store the output block
vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
- acc1 = acc1 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
- acc2 = acc2 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
- acc3 = acc3 * (half8)ALPHA;
-#endif // defined(ALPHA)
vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // REINTERPRET_OUTPUT_AS_3D
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
@@ -2135,6 +2575,9 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -2319,6 +2762,9 @@ __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
*/
__kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -2471,20 +2917,24 @@ __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
@@ -2509,20 +2959,24 @@ __kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
@@ -2550,20 +3004,24 @@ __kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_qs8(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
char16 alpha_ab = vload16(0, (__global char *)dst.ptr);
@@ -2589,20 +3047,24 @@ __kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
- IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_qs16(TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
{
// Compute source and destination addresses
- Image src = CONVERT_TO_IMAGE_STRUCT(src);
- Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
// Load values from A x B
short8 alpha_ab = vload8(0, (__global short *)dst.ptr);
diff --git a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
index e6a1bafa72..c50ee24a70 100644
--- a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
@@ -126,14 +126,14 @@ void CLGEMMMatrixAdditionKernel::run(const Window &window, cl::CommandQueue &que
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
- Window slice = window.first_slice_window_2D();
+ Window slice = window.first_slice_window_3D();
do
{
unsigned int idx = 0;
- add_2D_tensor_argument(idx, _input, slice);
- add_2D_tensor_argument(idx, _output, slice);
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _output, slice);
enqueue(queue, *this, slice);
}
- while(window.slide_window_slice_2D(slice));
+ while(window.slide_window_slice_3D(slice));
}
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index fc52f4e124..2c2a92d070 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -56,19 +56,13 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_fixed_point(input0->data_type()) && (reshape_info.depth_output_gemm3d() != 1), "GEMM3D only supports floating point data types");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
if(!is_interleaved_transposed)
{
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
-
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
- ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
- }
}
else
{
@@ -94,14 +88,14 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+ }
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
- ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
- }
+ if(output->total_size() != 0)
+ {
+ const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
}
return Status{};
@@ -113,6 +107,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
{
bool window_changed = false;
Window win{};
+ Window win_out{};
const DataType data_type = input0->data_type();
unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
@@ -121,23 +116,43 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
// Output tensor auto inizialitation if not yet initialized
auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)));
+ TensorInfo tmp_info(*output);
+
+ if(reshape_info.depth_output_gemm3d() != 1)
+ {
+ // Since the output tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM,
+ // the window needs to be constructed on the 2D collapsed version of the tensor
+ TensorShape tmp_shape(output->tensor_shape());
+ tmp_shape.collapse(2U, 1U);
+ tmp_info.set_tensor_shape(tmp_shape);
+ }
+
if(is_interleaved_transposed)
{
// Configure kernel window
num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
num_elems_processed_per_iteration_y = 4;
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
+ // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
+ const int m = reshape_info.m();
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
+ win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
AccessWindowRectangle input0_access(input0, 0, 0, num_elems_processed_per_iteration_y, 1, 1.f, 0.25f);
AccessWindowStatic input1_access(input1, 0, 0,
ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x),
ceil_to_multiple(input1->dimension(1), num_elems_processed_per_iteration_y));
- AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+ AccessWindowStatic output_access(output, 0, 0,
+ ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
+ output->dimension(1) + bottom_pad);
- window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);
+ window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+ update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
- output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
+ output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
}
else // The input tensors have not been reshaped
{
@@ -145,6 +160,11 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
+ // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
+ // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
+ const int m = input0->tensor_shape()[1];
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
// Create kernels according to the architecture, data type and input size.
GPUTarget arch_target = get_arch_from_target(gpu_target);
if(arch_target == GPUTarget::BIFROST && data_type == DataType::F32)
@@ -153,17 +173,21 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
}
// Configure window
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
- AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
- AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+ AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
+ AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
+ AccessWindowStatic output_access(output, 0, 0,
+ ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
+ output->dimension(1) + bottom_pad);
- window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);
+ window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+ update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
Coordinates coord;
coord.set_num_dimensions(output->num_dimensions());
- output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
+ output_access.set_valid_region(win_out, ValidRegion(coord, output->tensor_shape()));
}
// Collapse along the Z direction
@@ -178,7 +202,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
} // namespace
CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
- : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
+ : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _is_gemm3d(false)
{
}
@@ -194,9 +218,14 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
_output = output;
_slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
- const DataType data_type = input0->info()->data_type();
- const int fp_pos = input0->info()->fixed_point_position();
- const GPUTarget gpu_target = get_target();
+ const DataType data_type = input0->info()->data_type();
+ const int fp_pos = input0->info()->fixed_point_position();
+
+ // Get target architecture
+ GPUTarget gpu_target = get_target();
+
+ // Check if the output has to be reinterpreted as 3D
+ _is_gemm3d = (reshape_info.depth_output_gemm3d() != 1) && is_data_type_float(data_type);
ElementsProcessed num_elements_processed{};
@@ -216,6 +245,9 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
"-DALPHA=" + support::cpp11::to_string((data_type == DataType::QS8 ? sqcvt_qs8_f32(alpha, fp_pos) : sqcvt_qs16_f32(alpha, fp_pos))),
"-DALPHA=" + float_to_string_with_full_precision(alpha));
}
+ build_opts.add_option_if(_is_gemm3d, "-DREINTERPRET_OUTPUT_AS_3D");
+ build_opts.add_option_if(_is_gemm3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
+ build_opts.add_option_if(_is_gemm3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
// Do not slide matrix B if _slide_matrix_b = false
build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
@@ -285,6 +317,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Set config_id for enabling LWS tuning
_config_id = "gemm_";
_config_id += (is_interleaved_transposed ? "reshaped_" : "");
+ _config_id += (_is_gemm3d ? "3d_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
@@ -334,6 +367,13 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
+ if(_is_gemm3d)
+ {
+ // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
+ const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(_output->info()->padding().bottom));
+ }
+
do
{
Window slice_b = slice;
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index a0ec66f804..f9713bb586 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -24,10 +24,6 @@
#include "arm_compute/runtime/CL/functions/CLGEMM.h"
#include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMTranspose1xWKernel.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/Helpers.h"
@@ -111,6 +107,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
const int m = a->info()->dimension(1);
const int n = b->info()->dimension(0);
const int k = a->info()->dimension(0);
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
int mult_transpose1xW_width = 1;
int mult_interleave4x4_height = 1;
@@ -144,7 +141,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
}
// Configure and tune matrix multiply kernel
- _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height));
+ _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d));
CLScheduler::get().tune_kernel_static(_mm_kernel);
if(_is_interleaved_transposed)
@@ -197,7 +194,7 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
mult_interleave4x4_height = 2;
}
- const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
// Check if we need to reshape the matrix A and matrix B
const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 638eb904ad..bc9a056e7c 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -48,6 +48,20 @@ public:
add_config(TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 0.4f, 0.7f);
}
};
+class LargeGEMM3DDataset final : public GEMMDataset
+{
+public:
+ LargeGEMM3DDataset()
+ {
+ add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+ add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+ add_config(TensorShape(364, 3025), TensorShape(96, 364), TensorShape(96, 605, 5), TensorShape(96, 605, 5), 1.0f, 0.0f);
+ add_config(TensorShape(1201, 729), TensorShape(128, 1201), TensorShape(128, 243, 3), TensorShape(128, 243, 3), 1.0f, 0.0f);
+ add_config(TensorShape(2305, 169), TensorShape(384, 2305), TensorShape(384, 13, 13), TensorShape(384, 13, 13), 1.0f, 0.0f);
+ add_config(TensorShape(1729, 170), TensorShape(192, 1729), TensorShape(192, 85, 2), TensorShape(192, 85, 2), 1.0f, 0.0f);
+ add_config(TensorShape(1729, 170), TensorShape(128, 1729), TensorShape(128, 17, 10), TensorShape(128, 17, 10), 1.0f, 0.0f);
+ }
+};
} // namespace datasets
} // namespace test
} // namespace arm_compute
diff --git a/tests/datasets/MatrixMultiplyGEMMDataset.h b/tests/datasets/MatrixMultiplyGEMMDataset.h
index 5693706ab3..fd2a3d6610 100644
--- a/tests/datasets/MatrixMultiplyGEMMDataset.h
+++ b/tests/datasets/MatrixMultiplyGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index 110868bcbd..c9bf674ad3 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -49,6 +49,19 @@ public:
add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U), TensorShape(17U, 1U), 0.4f, 0.7f);
}
};
+class SmallGEMM3DDataset final : public GEMMDataset
+{
+public:
+ SmallGEMM3DDataset()
+ {
+ add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U, 7U, 2U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U, 1U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 4U, 3U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f);
+ add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U, 1U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f);
+ add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U, 8U, 2U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U, 8U, 2U, 5U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
+ }
+};
} // namespace datasets
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/CL/GEMM.cpp b/tests/validation/CL/GEMM.cpp
index 217edf4438..4e6cf826fa 100644
--- a/tests/validation/CL/GEMM.cpp
+++ b/tests/validation/CL/GEMM.cpp
@@ -250,8 +250,44 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixedPointFixture<int16_t>, framework::Da
TEST_SUITE_END()
TEST_SUITE_END()
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE(OUTPUT_3D)
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMM3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMM3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMM3DDataset(),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMM3DDataset(),
+ framework::dataset::make("DataType",
+ DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+TEST_SUITE_END() // FP16
+
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // OUTPUT_3D
+
+TEST_SUITE_END() // GEMM
+TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index e807ad8c06..8dd2998377 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,12 @@ protected:
// Create and configure function
FunctionType gemm;
- gemm.configure(&a, &b, &c, &dst, alpha, beta);
+ // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output.
+ // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 1),
+ // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output).
+ bool is_output_reinterpreted_as_3D = output_shape.num_dimensions() > shape_a.num_dimensions();
+ gemm.configure(&a, &b, &c, &dst, alpha, beta,
+ GEMMInfo(false, false, false, is_output_reinterpreted_as_3D ? output_shape[2] : 1));
ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 7638759b46..f74623ea69 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -952,7 +952,6 @@ inline ::std::ostream &operator<<(::std::ostream &os, const GEMMInfo &info)
os << "{is_a_reshaped=" << info.is_a_reshaped() << ",";
os << "is_b_reshaped=" << info.is_b_reshaped() << ",";
os << "reshape_b_only_on_first_run=" << info.reshape_b_only_on_first_run() << ",";
- os << "reshape_info=" << info.reshape_info();
os << "}";
return os;