author    Gian Marco Iodice <gianmarco.iodice@arm.com>    2018-07-26 11:44:03 +0100
committer Anthony Barbier <anthony.barbier@arm.com>       2018-11-02 16:54:54 +0000
commit    68a3f56627b04acdefebe67d645727dd83889766 (patch)
tree      4a3f4dc0facfda861a5ba7afa29d84d82d0829c2
parent    4e0d3819be6c61cc00c7e0fa9b4b740738c703b7 (diff)
download  ComputeLibrary-68a3f56627b04acdefebe67d645727dd83889766.tar.gz
COMPMID-1276 - Allow GEMM to work with 3D input tensor
Skipped im2col in CLGEMMConvolutionLayer for 1x1 convolutions with NHWC data layout

Change-Id: I894e6b952ed8605e8f3ffc0ffc25c24730d4664c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141909
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
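The sketch below illustrates how a caller might drive the new flags once this patch is applied. It is an illustration only: the tensor objects and shape values are invented, the usual arm_compute headers and CL scheduler setup are assumed, and only the GEMMInfo constructor extended below is taken from this patch; the CLGEMM::configure call is assumed from the library's public API of the time.

    // Hypothetical usage sketch (shapes invented for illustration).
    // For a 1x1 convolution with NHWC layout, im2col is a no-op: the input
    // [C, W, H, N] can be handed to GEMM directly and reinterpreted as a
    // 2D matrix with M = W * H rows per batch.
    const int conv_h = 7; // assumed output height
    GEMMInfo gemm_info(false  /* is_a_reshaped */,
                       false  /* is_b_reshaped */,
                       true   /* reshape_b_only_on_first_run */,
                       conv_h /* depth_output_gemm3d: write the output back as 3D */,
                       true   /* reinterpret_input_as_3d: collapse W and H into M */);
    CLGEMM gemm;
    gemm.configure(&input, &weights, nullptr, &output, 1.0f, 0.0f, gemm_info);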
-rw-r--r--   arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h      7
-rw-r--r--   arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h     3
-rw-r--r--   arm_compute/core/Types.h                                    48
-rw-r--r--   arm_compute/core/utils/misc/ShapeCalculator.h               31
-rw-r--r--   arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h   23
-rw-r--r--   src/core/CL/cl_kernels/gemm.cl                             375
-rw-r--r--   src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp           73
-rw-r--r--   src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp          47
-rw-r--r--   src/runtime/CL/functions/CLGEMM.cpp                         32
-rw-r--r--   src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp         76
-rw-r--r--   src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp    7
-rw-r--r--   tests/datasets/LargeGEMMDataset.h                           30
-rw-r--r--   tests/datasets/SmallGEMMDataset.h                           18
-rw-r--r--   tests/validation/CL/GEMM.cpp                                60
-rw-r--r--   tests/validation/fixtures/GEMMFixture.h                     28
15 files changed, 683 insertions(+), 175 deletions(-)
diff --git a/arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h b/arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h
index 7f8e766f1a..4592fc2921 100644
--- a/arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h
@@ -67,17 +67,19 @@ public:
* @param[in] input Input tensor. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[out] output Output tensor. Data type supported: same as @p input
* @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleave block
+ * @param[in] reinterpret_input_as_3d (Optional) True if the input has to be reinterpreted as a 3D tensor
*/
- void configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height = 1);
+ void configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height = 1, bool reinterpret_input_as_3d = false);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMInterleave4x4Kernel
*
* @param[in] input Input tensor info. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] output Output tensor info which stores the interleaved matrix. Data type supported: same as @p input.
* @param[in] mult_interleave4x4_height Multiplication factor for the height of the 4x4 interleave block
+ * @param[in] reinterpret_input_as_3d True if the input has to be reinterpreted as a 3D tensor
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d);
// Inherited methods overridden
void run(const Window &window, cl::CommandQueue &queue) override;
@@ -85,6 +87,7 @@ public:
private:
const ICLTensor *_input;
ICLTensor *_output;
+ bool _reinterpret_input_as_3d;
};
} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLGEMMINTERLEAVE4X4KERNEL_H__ */
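A minimal sketch of the extended interface, assuming a CL tensor already initialised with a 3D shape [K, H, D] (the tensor names are invented; only the configure signature comes from the header above):

    // Interleave a [K, H, D] input as if it were the 2D matrix [K, H * D].
    CLGEMMInterleave4x4Kernel interleave;
    interleave.configure(&input, &interleaved,
                         1    /* mult_interleave4x4_height */,
                         true /* reinterpret_input_as_3d */);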
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
index 1b6a0c87a9..e030fa2d2a 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
@@ -85,7 +85,8 @@ public:
const ICLTensor *_input1;
ICLTensor *_output;
bool _slide_matrix_b;
- bool _is_gemm3d;
+ bool _reinterpret_input_as_3d;
+ bool _reinterpret_output_as_3d;
};
} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLGEMMMATRIXMULTIPLYKERNEL_H__ */
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 00370918bd..81d652dd7d 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1031,7 +1031,7 @@ class GEMMReshapeInfo final
public:
/** Default constructor */
GEMMReshapeInfo()
- : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(1)
+ : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false)
{
}
/** Constructor
@@ -1042,9 +1042,12 @@ public:
* @param[in] mult_transpose1xW_width (Optional) Multiplication factor for the width of the 1xW transposed block
* @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block
* @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
+ * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as a 3D tensor. (i.e. this flag should be set to true when GEMM is used
+ * to perform 1x1 convolutions with the NHWC data layout)
*/
- GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 1)
- : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d)
+ GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false)
+ : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
+ _reinterpret_input_as_3d(reinterpret_input_as_3d)
{
}
/** Number of matrix A rows
@@ -1098,14 +1101,23 @@ public:
{
return _depth_output_gemm3d;
}
+ /** Flag which specifies if the input tensor has to be reinterpreted as 3D
+ *
+ * @return True if the input tensor has to be reinterpreted as a 3D tensor
+ */
+ bool reinterpret_input_as_3d() const
+ {
+ return _reinterpret_input_as_3d;
+ };
private:
- const int _m;
- const int _n;
- const int _k;
- const int _mult_transpose1xW_width;
- const int _mult_interleave4x4_height;
- const int _depth_output_gemm3d;
+ const int _m;
+ const int _n;
+ const int _k;
+ const int _mult_transpose1xW_width;
+ const int _mult_interleave4x4_height;
+ const int _depth_output_gemm3d;
+ const bool _reinterpret_input_as_3d;
};
/** GEMM information class. This class stores the necessary information to compute GEMM functions
@@ -1118,7 +1130,7 @@ class GEMMInfo
public:
/** Default constructor */
GEMMInfo()
- : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1)
+ : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false)
{
}
/** Constructor
@@ -1127,10 +1139,13 @@ public:
* @param[in] is_b_reshaped True if the matrix B has been reshaped
* @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run
* @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel
+ * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as a 3D tensor. (i.e. this flag should be set to true when GEMM is used
+ * to perform 1x1 convolutions with the NHWC data layout)
*
*/
- GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1)
- : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d)
+ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false)
+ : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d),
+ _reinterpret_input_as_3d(reinterpret_input_as_3d)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -1167,12 +1182,21 @@ public:
{
return _depth_output_gemm3d;
};
+ /** Flag which specifies if the input tensor has to be reinterpreted as 3D
+ *
+ * @return True if the input tensor has to be reinterpreted as a 3D tensor
+ */
+ bool reinterpret_input_as_3d() const
+ {
+ return _reinterpret_input_as_3d;
+ };
private:
const bool _is_a_reshaped;
const bool _is_b_reshaped;
const bool _reshape_b_only_on_first_run;
const int _depth_output_gemm3d;
+ const bool _reinterpret_input_as_3d;
};
/** Winograd information */
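Read concretely, the new constructor argument composes with the existing ones as follows. The shape values here are hypothetical, chosen for a 1x1 NHWC convolution on an input of shape [C=64, W=7, H=7]:

    // A is the un-reshaped input read as 2D: M = W * H rows, K = C columns.
    const int m = 7 * 7; // 49
    const int n = 128;   // output channels
    const int k = 64;    // input channels
    // depth_output_gemm3d = H, so the output is written back as a 3D tensor;
    // reinterpret_input_as_3d = true, so the kernels collapse W and H on load.
    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, 7, true);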
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index dbf26a423d..bf55add1d2 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -66,14 +66,24 @@ inline TensorShape compute_weights_reshaped_shape(const ITensorInfo &weights, bo
return weights_reshaped;
}
-inline TensorShape compute_interleaved_shape(const ITensorInfo &a, int mult_interleave4x4_height = 1)
+inline TensorShape compute_interleaved_shape(const ITensorInfo &a, int mult_interleave4x4_height = 1, bool reinterpret_input_as_3d = false)
{
// The interleaved output matrix will have the following shape: [ a_height * W, ceil(a_width / W) ] where W = 4 * mult_interleave4x4_height
ARM_COMPUTE_ERROR_ON(mult_interleave4x4_height < 1);
const int interleave_width = 4 * mult_interleave4x4_height;
TensorShape shape_interleaved_a{ a.tensor_shape() };
shape_interleaved_a.set(0, a.dimension(0) * interleave_width);
- shape_interleaved_a.set(1, std::ceil(a.dimension(1) / static_cast<float>(interleave_width)));
+ if(reinterpret_input_as_3d)
+ {
+ const int M = a.dimension(1) * a.dimension(2);
+ const int height = std::ceil(M / static_cast<float>(interleave_width));
+ shape_interleaved_a.set(1, height);
+ shape_interleaved_a.remove_dimension(2);
+ }
+ else
+ {
+ shape_interleaved_a.set(1, std::ceil(a.dimension(1) / static_cast<float>(interleave_width)));
+ }
return shape_interleaved_a;
}
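Worked through with a hypothetical input a = [K=8, H=6, D=4] and mult_interleave4x4_height = 1 (so interleave_width = 4): with reinterpret_input_as_3d = true, M = 6 * 4 = 24, height = ceil(24 / 4) = 6, and the interleaved shape becomes [8 * 4, 6] = [32, 6], with the now-redundant third dimension removed.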
@@ -374,23 +384,26 @@ inline TensorShape compute_rnn_shape(const ITensorInfo *input, const unsigned in
inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
{
ARM_COMPUTE_ERROR_ON_MSG(input0.num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
+ ARM_COMPUTE_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The first input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
- const bool is_gemm3d = reshape_info.depth_output_gemm3d() != 1;
+ const bool reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
+ const bool reinterpret_output_as_3d = reshape_info.depth_output_gemm3d() != 1;
+ const int m = reshape_info.reinterpret_input_as_3d() ? input0.dimension(1) * input0.dimension(2) : input0.dimension(1);
// If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
// dimension of the output tensor
const int dim0 = is_interleaved_transposed ? reshape_info.n() : input1.dimension(0);
- const int dim1 = is_interleaved_transposed ? reshape_info.m() / reshape_info.depth_output_gemm3d() : input0.dimension(1) / reshape_info.depth_output_gemm3d();
- const int dim2 = input0.tensor_shape()[2];
- const int dim3 = input0.tensor_shape()[3];
+ const int dim1 = is_interleaved_transposed ? reshape_info.m() / reshape_info.depth_output_gemm3d() : m / reshape_info.depth_output_gemm3d();
+ const int dim2 = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
+ const int dim3 = reinterpret_input_as_3d ? 1 : input0.tensor_shape()[3];
TensorShape output_shape{ input0.tensor_shape() };
output_shape.set(0, dim0);
output_shape.set(1, dim1);
- output_shape.set(2, is_gemm3d ? reshape_info.depth_output_gemm3d() : dim2);
- output_shape.set(3, is_gemm3d ? dim2 : dim3);
- output_shape.set(4, is_gemm3d ? dim3 : 1);
+ output_shape.set(2, reinterpret_output_as_3d ? reshape_info.depth_output_gemm3d() : dim2);
+ output_shape.set(3, reinterpret_output_as_3d ? dim2 : dim3);
+ output_shape.set(4, reinterpret_output_as_3d ? dim3 : 1);
return output_shape;
}
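Continuing the hypothetical example above (input0 = [K=64, W=7, H=7], input1 = [N=128, K=64], reshape_info = GEMMReshapeInfo(49, 128, 64, 1, 1, 7, true), is_interleaved_transposed = false): m = 7 * 7 = 49, dim0 = 128, dim1 = 49 / 7 = 7, and since depth_output_gemm3d != 1 the function returns [128, 7, 7, 1, 1], i.e. a [N, W, H] output per batch.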
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index 09daa5f568..7c272a348b 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -26,8 +26,8 @@
#include "arm_compute/runtime/IFunction.h"
+#include "arm_compute/core/CL/kernels/CLArithmeticAdditionKernel.h"
#include "arm_compute/core/CL/kernels/CLCol2ImKernel.h"
-#include "arm_compute/core/CL/kernels/CLFillBorderKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMTranspose1xWKernel.h"
@@ -83,18 +83,12 @@ private:
/** Basic function to compute the convolution layer. This function calls the following OpenCL kernels/functions:
*
- * Note: weights already reshaped for quantized asymmetric is not supported
- *
* -# @ref CLIm2ColKernel
- * -# @ref CLGEMMLowpMatrixMultiplyCore (if quantized asymmetric)
- * -# @ref CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint (if quantized asymmetric)
- * -# @ref CLCol2ImKernel
- *
- * if the weights are already reshaped:
- * -# @ref CLGEMMInterleave4x4Kernel
- * -# @ref CLGEMMMatrixMultiplyKernel
- * else
- * -# @ref CLGEMM
+ * -# @ref CLGEMM (if the data type is FP32 or FP16)
+ * -# @ref CLGEMMLowpMatrixMultiplyCore (if the data type is QASYMM8)
+ * -# @ref CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint (if the data type is QASYMM8)
+ * -# @ref CLArithmeticAdditionKernel (if biases != nullptr and we have a 1x1 convolution with the NHWC data layout)
+ * -# @ref CLCol2ImKernel (if NCHW data layout)
*/
class CLGEMMConvolutionLayer : public IFunction
{
@@ -172,10 +166,11 @@ private:
* @param[in] output Output tensor. Data types supported: Same as @p input,
* except for input of QASYMM8 type where output should be of S32 type.
* @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
+ * @param[in] skip_im2col (Optional) Flag which specifies if im2col has to be skipped (i.e. for a 1x1 convolution with the NHWC data layout). (Defaults to false)
*
* @return a status
*/
- static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1);
+ static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1, bool skip_im2col = false);
private:
CLMemoryGroup _memory_group;
@@ -186,6 +181,7 @@ private:
CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
CLCol2ImKernel _col2im_kernel;
CLActivationLayer _activationlayer_function;
+ CLArithmeticAdditionKernel _add_bias_kernel;
const ICLTensor *_original_weights;
@@ -196,6 +192,7 @@ private:
DataLayout _data_layout;
+ bool _append_bias;
bool _skip_im2col;
bool _is_quantized;
bool _is_activationlayer_enabled;
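For orientation, a hedged sketch of the dispatch condition follows; the exact predicate lives in CLGEMMConvolutionLayer.cpp, which is not shown in full here, so treat the stride handling in particular as an assumption rather than the literal library source:

    // Hypothetical, simplified predicate (not the literal library source).
    bool should_skip_im2col(bool is_nhwc, unsigned kernel_w, unsigned kernel_h,
                            unsigned stride_x, unsigned stride_y)
    {
        // With NHWC and a 1x1 kernel the im2col output equals the input, so
        // the copy can be skipped; unit strides are presumably required too.
        return is_nhwc && kernel_w == 1 && kernel_h == 1
               && stride_x == 1 && stride_y == 1;
    }

With im2col skipped there is no appended bias row, which is why a non-null bias is added afterwards through the new CLArithmeticAdditionKernel member instead.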
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 5a6efe64b9..932e0d681a 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -88,6 +88,11 @@ __kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
*
* @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the input has to be reinterpreted as a 3D tensor (i.e. the input of a 1x1 convolution layer), the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
*
* @param[in] src_ptr Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
* @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -105,9 +110,15 @@ __kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
*/
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst))
+ TENSOR3D_DECLARATION(dst)
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+ )
{
// Compute source and destination addresses
uint x = get_global_id(0);
@@ -124,6 +135,45 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
// Add offset for batched GEMM
dst_addr_in_bytes += z * dst_stride_z;
+#if defined(REINTERPRET_INPUT_AS_3D)
+ __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;
+
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated by dividing M (y * 4) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (cross_plane_pad * src_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src_stride_z by DEPTH_GEMM3D
+ input_ptr += z * src_stride_z * DEPTH_GEMM3D;
+
+ // Load values from Matrix A
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
+#else // defined(REINTERPRET_INPUT_AS_3D)
__global uchar *input_ptr = src.ptr;
// Load values from Matrix A
@@ -135,6 +185,7 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
VEC_DATA_TYPE(DATA_TYPE, 4)
a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
VEC_DATA_TYPE(DATA_TYPE, 4)
val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
@@ -188,7 +239,7 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -366,7 +417,7 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -679,7 +730,7 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -853,7 +904,7 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1095,7 +1146,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1122,7 +1174,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0)
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1130,9 +1183,13 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -1147,9 +1204,40 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(DATA_TYPE);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1172,6 +1260,23 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
VEC_DATA_TYPE(DATA_TYPE, 2)
a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1187,6 +1292,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
VEC_DATA_TYPE(DATA_TYPE, 2)
a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
@@ -1210,6 +1317,19 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1221,6 +1341,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
@@ -1280,7 +1402,7 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1335,7 +1457,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1362,7 +1485,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1370,9 +1494,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -1387,9 +1515,40 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
// Update address for matrix B
src_addr.s1 += idx * sizeof(float);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1428,6 +1587,19 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
int i = 0;
for(; i <= ((int)COLS_A - 4); i += 4)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A and matrix B
+ float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A and matrix B
float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1439,6 +1611,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1579,8 +1753,21 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
- float a0 = *((__global float *)(src0_ptr + src_addr.s0));
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1590,6 +1777,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1676,7 +1865,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1723,7 +1912,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1750,7 +1940,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -1758,9 +1949,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -1776,9 +1971,40 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(float);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1807,8 +2033,13 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
int i = 0;
for(; i <= ((int)COLS_A - 8); i += 8)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
@@ -1848,7 +2079,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc01 = fma(a0.s7, b7.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc10 = fma(a0.s0, b0.s0, acc10);
acc10 = fma(a0.s1, b1.s0, acc10);
acc10 = fma(a0.s2, b2.s0, acc10);
@@ -1868,7 +2103,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc20 = fma(a0.s0, b0.s0, acc20);
acc20 = fma(a0.s1, b1.s0, acc20);
acc20 = fma(a0.s2, b2.s0, acc20);
@@ -1888,7 +2127,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
acc30 = fma(a0.s0, b0.s0, acc30);
acc30 = fma(a0.s1, b1.s0, acc30);
acc30 = fma(a0.s2, b2.s0, acc30);
@@ -1913,6 +2156,19 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
// float size increment
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1924,6 +2180,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -1994,7 +2252,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -2041,7 +2299,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
* This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
* -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -2068,7 +2327,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] cross_plane_pad Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
IMAGE_DECLARATION(src1),
@@ -2076,9 +2336,13 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
uint src0_stride_z,
uint src1_stride_z,
uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
- uint cross_plane_pad
+ uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
@@ -2093,9 +2357,40 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
// Update address for the matrix B
src_addr.s1 += idx * sizeof(half);
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -2117,6 +2412,19 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
int i = 0;
for(; i <= ((int)COLS_A - 4); i += 4)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -2128,6 +2436,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
src_addr.s1 += src1_stride_y;
@@ -2188,6 +2498,19 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
for(; i < (int)COLS_A; ++i)
{
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -2199,6 +2522,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
// Load values from matrix B
half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
@@ -2260,7 +2585,7 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
zout = min(DEPTH_GEMM3D - 1, zout);
// Add offset due to the cross plane paddings
- zout *= (cross_plane_pad * dst_stride_y);
+ zout *= (dst_cross_plane_pad * dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
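Because the zin computation above repeats verbatim in each kernel, a small self-contained sketch may help make the arithmetic concrete. All values are invented for illustration, and the code only mirrors the index math on the host, not the kernels themselves:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        // Invented example: HEIGHT_GEMM3D = 6 rows per plane, DEPTH_GEMM3D = 4 planes,
        // 2 padding rows between consecutive planes, 64-byte row stride.
        const unsigned height_gemm3d = 6, depth_gemm3d = 4;
        const unsigned cross_plane_pad = 2, stride_y = 64;
        for(unsigned row = 0; row < height_gemm3d * depth_gemm3d; ++row)
        {
            const unsigned zin   = std::min(row / height_gemm3d, depth_gemm3d - 1); // plane index
            const unsigned bytes = zin * cross_plane_pad * stride_y; // padding bytes to skip
            std::printf("row %2u -> plane %u, cross-plane offset %u bytes\n", row, zin, bytes);
        }
        return 0;
    }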
diff --git a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
index 12a40cd7dc..6ea1160c69 100644
--- a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
+++ b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
+#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/CLValidate.h"
@@ -30,6 +31,7 @@
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Window.h"
@@ -40,7 +42,7 @@ using namespace arm_compute::misc::shape_calculator;
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
ARM_COMPUTE_RETURN_ERROR_ON(mult_interleave4x4_height < 1);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
@@ -50,24 +52,30 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, i
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height, reinterpret_input_as_3d));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
constexpr unsigned int num_elems_processed_per_iteration_x = 4;
constexpr unsigned int num_elems_processed_per_iteration_y = 4;
const unsigned int num_elems_written_per_iteration = num_elems_processed_per_iteration_x * num_elems_processed_per_iteration_y * mult_interleave4x4_height;
bool window_changed = false;
- // Configure kernel window
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
- window_changed = window_changed || update_window_and_padding(win, input_access);
+ TensorInfo tmp_info(*input);
+
+ if(reinterpret_input_as_3d)
+ {
+ // Since the input tensor has to be reinterpreted as 3D and the execution window is based on a 2D interleave,
+ // the window needs to be constructed on the 2D collapsed version of the tensor
+ TensorShape tmp_shape(input->tensor_shape());
+ tmp_shape.collapse(2U, 1U);
+ tmp_info.set_tensor_shape(tmp_shape);
+ }
// Output auto initialization if not yet initialized
auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_interleaved_shape(*input, mult_interleave4x4_height)));
@@ -76,9 +84,22 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
const float scale_x = 4.0f * static_cast<float>(mult_interleave4x4_height);
const float scale_y = 1.0f / (scale_x);
+ // Note: the bottom padding is calculated manually as the input can be reinterpreted as a 3D tensor
+ // The only way to set the padding properly is to set it explicitly through an AccessWindowStatic
+ const int m = reinterpret_input_as_3d ? input->tensor_shape()[1] * input->tensor_shape()[2] : input->tensor_shape()[1];
+ const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
+ Window win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ Window win_in = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+ AccessWindowStatic input_access(input, 0, 0,
+ ceil_to_multiple(input->dimension(0), num_elems_processed_per_iteration_x),
+ input->dimension(1) + bottom_pad);
AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration, 1, scale_x, scale_y);
- window_changed = window_changed || update_window_and_padding(win, output_access);
- output_access.set_valid_region(win, input->valid_region());
+
+ window_changed = update_window_and_padding(win_in, input_access) || // window used by the execute_window_loop
+ update_window_and_padding(win, output_access); // window used to update the padding requirements of output tensor
+ output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
@@ -90,26 +111,31 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
} // namespace
CLGEMMInterleave4x4Kernel::CLGEMMInterleave4x4Kernel()
- : _input(nullptr), _output(nullptr)
+ : _input(nullptr), _output(nullptr), _reinterpret_input_as_3d(false)
{
}
-void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height)
+void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Output auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height)));
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height, reinterpret_input_as_3d)));
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d));
- _input = input;
- _output = output;
+ _input = input;
+ _output = output;
+ _reinterpret_input_as_3d = reinterpret_input_as_3d;
// Create build options
CLBuildOptions build_opts;
build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(1)));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(2)));
+
switch(input->info()->element_size())
{
case 1:
@@ -129,12 +155,13 @@ void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *out
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_interleave4x4", build_opts.options()));
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height);
+ auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure(win_config.second);
// Set config_id for enabling LWS tuning
_config_id = "interleave4x4_";
+ _config_id += (_reinterpret_input_as_3d ? "3d_" : "");
_config_id += lower_string(string_from_data_type(input->info()->data_type()));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(0));
@@ -146,10 +173,10 @@ void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *out
_config_id += support::cpp11::to_string(output->info()->dimension(3));
}
-Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height, reinterpret_input_as_3d));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height, reinterpret_input_as_3d).first);
return Status{};
}
@@ -170,6 +197,14 @@ void CLGEMMInterleave4x4Kernel::run(const Window &window, cl::CommandQueue &queu
*/
Window slice = window.first_slice_window_3D();
+ if(_reinterpret_input_as_3d)
+ {
+ // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
+ const unsigned int idx0 = 2 * num_arguments_per_3D_tensor();
+ const unsigned int total_cross_plane_pad = _input->info()->padding().top + _input->info()->padding().bottom;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+ }
+
do
{
unsigned int idx = 0;
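
The REINTERPRET_INPUT_AS_3D path hinges on the bottom-padding arithmetic introduced above: the execution window is built on the 2D-collapsed shape, and run() forwards the summed top/bottom padding so the kernel can step across planes. A minimal standalone sketch of that arithmetic, with hypothetical tensor dimensions:

#include <cstdio>

int main()
{
    // Hypothetical input reinterpreted as 3D: dim0 x dim1 x dim2 = 8 x 7 x 3
    const unsigned int dim1   = 7;
    const unsigned int dim2   = 3;
    const unsigned int step_y = 4; // num_elems_processed_per_iteration_y

    // Same formula as in validate_and_configure_window()
    const unsigned int m          = dim1 * dim2;                    // 21
    const unsigned int bottom_pad = (step_y - m % step_y) % step_y; // (4 - 1) % 4 = 3

    // run() passes top + bottom padding to the kernel as a single cl_uint,
    // mirroring total_cross_plane_pad (top padding assumed 0 here)
    printf("m=%u bottom_pad=%u cross_plane_pad=%u\n", m, bottom_pad, 0U + bottom_pad);
    return 0;
}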
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 0c629af788..c9e6bb34b2 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -56,6 +56,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
if(!is_interleaved_transposed)
{
@@ -125,6 +126,9 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
if(is_interleaved_transposed)
{
+ // reinterpret_input_as_3d is not supported if is_interleaved_transposed is set
+ ARM_COMPUTE_ERROR_ON(reshape_info.reinterpret_input_as_3d());
+
// Configure kernel window
num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
num_elems_processed_per_iteration_y = 4;
@@ -158,7 +162,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
// Note: the bottom padding is calculated manually as the input and output can be reinterpreted as 3D tensors
// The only way to set the padding properly is to set it explicitly through an AccessWindowStatic
- const int m = input0->tensor_shape()[1];
+ const int m = reshape_info.reinterpret_input_as_3d() ? input0->tensor_shape()[1] * input0->tensor_shape()[2] : input0->tensor_shape()[1];
const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
// Create kernels according to the architecture, data type and input size.
@@ -172,7 +176,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
- AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
+ AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), input0->dimension(1) + bottom_pad);
AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
AccessWindowStatic output_access(output, 0, 0,
ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
@@ -198,7 +202,7 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
} // namespace
CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
- : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _is_gemm3d(false)
+ : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false)
{
}
@@ -209,19 +213,22 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Perform validate step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
- _input0 = input0;
- _input1 = input1;
- _output = output;
- _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
+ _input0 = input0;
+ _input1 = input1;
+ _output = output;
+ _reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d();
+ _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 1);
+
+ // Check if we need to slide the matrix B
+ const unsigned int num_dimensions_input0 = _reinterpret_input_as_3d ? _input0->info()->num_dimensions() - 1 : _input0->info()->num_dimensions();
+
+ _slide_matrix_b = (_input1->info()->num_dimensions() >= num_dimensions_input0);
const DataType data_type = input0->info()->data_type();
// Get target architecture
GPUTarget gpu_target = get_target();
- // Check if the output has to be reinterpreted as 3D
- _is_gemm3d = (reshape_info.depth_output_gemm3d() != 1) && is_data_type_float(data_type);
-
ElementsProcessed num_elements_processed{};
// Configure kernel window
@@ -237,9 +244,10 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
{
build_opts.add_option("-DALPHA=" + float_to_string_with_full_precision(alpha));
}
- build_opts.add_option_if(_is_gemm3d, "-DREINTERPRET_OUTPUT_AS_3D");
- build_opts.add_option_if(_is_gemm3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
- build_opts.add_option_if(_is_gemm3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
+ build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
+ build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
// Do not slide matrix B if _slide_matrix_b = false
build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
@@ -305,7 +313,8 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Set config_id for enabling LWS tuning
_config_id = "gemm_";
_config_id += (is_interleaved_transposed ? "reshaped_" : "");
- _config_id += (_is_gemm3d ? "3d_" : "");
+ _config_id += (_reinterpret_input_as_3d ? "3di_" : "");
+ _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
@@ -355,10 +364,18 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que
slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
- if(_is_gemm3d)
+ if(_reinterpret_input_as_3d)
{
// Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+ const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
+ _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+ }
+
+ if(_reinterpret_output_as_3d)
+ {
+ // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor
+ const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
_kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
}
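
The two optional pad arguments are appended after the three 2D tensor argument blocks and the three trailing stride_z values, with the output slot shifting by one when the input pad is also present. A symbolic sketch of the index computation follows; args_per_2d_tensor stands in for num_arguments_per_2D_tensor(), whose real value comes from the library:

#include <cstdio>

// Symbolic sketch of the optional cl_uint argument slots set in
// CLGEMMMatrixMultiplyKernel::run(); not library code
unsigned int pad_arg_index(unsigned int args_per_2d_tensor, bool input_is_3d, bool output_pad)
{
    // Three 2D tensors (src0, src1, dst) plus their three stride_z arguments
    const unsigned int base = 3 * args_per_2d_tensor + 3;
    // The output pad slot shifts by one if the input pad argument is present
    return output_pad ? base + (input_is_3d ? 1U : 0U) : base;
}

int main()
{
    const unsigned int n = 6; // assumed size of a 2D tensor argument block
    printf("input pad at %u, output pad at %u\n",
           pad_arg_index(n, true, false), pad_arg_index(n, true, true));
    return 0;
}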
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1f4df4f1a9..1d1b17bbf1 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -102,7 +102,8 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
// in order to know how the matrices have been reshaped
- const int m = a->info()->dimension(1);
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
const int n = b->info()->dimension(0);
const int k = a->info()->dimension(0);
const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
@@ -118,6 +119,12 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
// Check if we need to reshape the matrix A and matrix B
_is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
+ // If _is_interleaved_transposed is set, force reinterpret_input_as_3d to false as the output of CLGEMMInterleave4x4Kernel will be 2D
+ if(_is_interleaved_transposed)
+ {
+ reinterpret_input_as_3d = false;
+ }
+
if(_is_interleaved_transposed)
{
matrix_a = &_tmp_a;
@@ -132,14 +139,15 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
// _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
// Configure interleave kernel
- _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height);
+ _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
// Configure transpose kernel
_transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
}
// Configure and tune matrix multiply kernel
- _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d));
+ _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
+ reinterpret_input_as_3d));
CLScheduler::get().tune_kernel_static(_mm_kernel);
if(_is_interleaved_transposed)
@@ -180,11 +188,13 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
// in order to know how the matrices have been reshaped
- const int m = a->dimension(1);
+ bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+ const int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
const int n = b->dimension(0);
const int k = a->dimension(0);
int mult_transpose1xW_width = 1;
int mult_interleave4x4_height = 1;
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
{
@@ -192,19 +202,25 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
mult_interleave4x4_height = 2;
}
- const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
-
// Check if we need to reshape the matrix A and matrix B
const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
+ // If run_interleave_transpose is set, force reinterpret_input_as_3d to false as the output of CLGEMMInterleave4x4Kernel will be 2D
+ if(run_interleave_transpose)
+ {
+ reinterpret_input_as_3d = false;
+ }
+
+ const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+
if(run_interleave_transpose)
{
matrix_a_info = &tmp_a_info;
matrix_b_info = &tmp_b_info;
// Validate interleave kernel
- auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height)));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height));
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
// Validate transpose kernel
auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
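
Both configure() and validate() now follow the same decision flow: m becomes the collapsed height when the input is 3D, and the 3D reinterpretation is dropped whenever the reshape path runs, since the interleave kernel consumes the 3D input and emits a 2D matrix. A condensed sketch of that flow; is_reshaped stands in for is_interleaved_transposed(), which the library derives from m, n, k and the GPU target:

struct GemmPlan
{
    int  m;
    bool reinterpret_input_as_3d;
};

// dim1/dim2 mirror a->info()->dimension(1) and dimension(2); sketch only
GemmPlan plan_gemm(int dim1, int dim2, bool input_is_3d, bool is_reshaped)
{
    GemmPlan plan{};
    plan.m = input_is_3d ? dim1 * dim2 : dim1;
    // The interleave kernel flattens the 3D input, so the matrix multiply
    // kernel must treat it as plain 2D afterwards
    plan.reinterpret_input_as_3d = input_is_3d && !is_reshaped;
    return plan;
}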
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index f1d2924c92..de628291eb 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -91,15 +91,15 @@ void CLConvolutionLayerReshapeWeights::run()
CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(),
- _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _skip_im2col(false), _is_quantized(false),
- _is_activationlayer_enabled(false), _is_prepared(false)
+ _add_bias_kernel(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
+ _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
{
}
void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
- ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info()));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), _skip_im2col));
if(_is_quantized)
{
@@ -120,15 +120,16 @@ void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTenso
else
{
// Configure matrix multiply function
- _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth));
+ _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
+ _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
}
}
-Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth)
+Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth, bool skip_im2col)
{
const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
- const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth);
+ const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
if(is_quantized)
{
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -180,7 +181,8 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
_original_weights = weights;
_is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
_data_layout = data_layout;
- _skip_im2col = false;
+ _skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !_is_quantized;
+ _append_bias = (biases != nullptr) && (!_is_quantized);
// Set the GPU target for im2col and col2im
_im2col_kernel.set_target(CLScheduler::get().target());
@@ -191,9 +193,8 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
ICLTensor *gemm_output_to_use = output;
ICLTensor *gemm_output_staged_to_use = output;
- const bool append_bias = (biases != nullptr) && (!_is_quantized);
- const unsigned bias_element = (append_bias) ? 1 : 0;
- const ICLTensor *biases_to_use = (append_bias) ? biases : nullptr;
+ const unsigned bias_element = (_append_bias && !_skip_im2col) ? 1 : 0;
+ const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
// Get parameters from conv_info
unsigned int stride_x = 0;
@@ -238,12 +239,17 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
_memory_group.manage(&_im2col_output);
// Configure and tune im2col
- _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation);
+ _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
CLScheduler::get().tune_kernel_static(_im2col_kernel);
// Update GEMM input
gemm_input_to_use = &_im2col_output;
}
+ else if(_append_bias)
+ {
+ // Configure add bias kernel
+ _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
+ }
// Create GEMM output tensor
if(!is_nhwc || _is_quantized)
@@ -281,28 +287,23 @@ void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *
float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
int output_multiplier, output_shift;
quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- if(!is_nhwc)
- {
- _memory_group.manage(&_tmp_output);
- gemm_output_staged_to_use = &_tmp_output;
- }
+
+ _memory_group.manage(&_tmp_output);
+ gemm_output_staged_to_use = &_tmp_output;
+
_gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset);
}
- if(!is_nhwc)
+ if(!is_nhwc || _is_quantized)
{
// Configure and tune Col2Im
_col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, std::make_pair(conv_w, conv_h));
CLScheduler::get().tune_kernel_static(_col2im_kernel);
}
- if(_is_quantized && !is_nhwc)
- {
- _tmp_output.allocator()->allocate();
- }
-
if(!is_nhwc || _is_quantized)
{
+ _tmp_output.allocator()->allocate();
_gemm_output.allocator()->allocate();
}
@@ -348,10 +349,10 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
const ITensorInfo *weights_to_use = weights;
const bool is_nhwc = data_layout == DataLayout::NHWC;
- const bool skip_im2col = false;
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
+ const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !is_quantized;
const bool append_bias = (biases != nullptr) && (!is_quantized);
- const unsigned bias_element = (append_bias) ? 1 : 0;
+ const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -410,6 +411,11 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
gemm_input_to_use = &im2col_reshaped_info;
}
+ else if(append_bias)
+ {
+ // Validate add bias kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
+ }
// Create GEMM output tensor
if(!is_nhwc || is_quantized)
@@ -424,25 +430,24 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
gemm_output_to_use = &info_gemm;
}
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1, skip_im2col));
if(is_quantized)
{
float multiplier = input->quantization_info().scale * weights_to_use->quantization_info().scale / output->quantization_info().scale;
int output_multiplier, output_shift;
quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- if(!is_nhwc)
- {
- tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
- tmp_info.set_quantization_info(output->quantization_info());
- gemm_output_staged_to_use = &tmp_info;
- }
+
+ tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
+ tmp_info.set_quantization_info(output->quantization_info());
+ gemm_output_staged_to_use = &tmp_info;
+
// Validate output stage for quantized case
CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset);
}
// Validate Col2Im
- if(!is_nhwc)
+ if(!is_nhwc || is_quantized)
{
ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use,
output,
@@ -485,8 +490,13 @@ void CLGEMMConvolutionLayer::run()
_mm_gemm.run();
}
+ if(_skip_im2col && _append_bias)
+ {
+ CLScheduler::get().enqueue(_add_bias_kernel);
+ }
+
// Reshape output matrix
- if(_data_layout == DataLayout::NCHW)
+ if(_data_layout == DataLayout::NCHW || _is_quantized)
{
CLScheduler::get().enqueue(_col2im_kernel, false);
}
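
The new _skip_im2col flag captures the point of the patch: for non-quantized NHWC 1x1 convolutions the input is already laid out as GEMM expects, so im2col (and the matching col2im) can be bypassed, and the bias is added by a separate kernel instead of travelling as an appended im2col element. The predicates, restated as a standalone sketch:

// Mirrors the conditions added to CLGEMMConvolutionLayer; sketch only
bool skip_im2col(bool is_nhwc, unsigned int kernel_w, unsigned int kernel_h, bool is_quantized)
{
    return is_nhwc && kernel_w == 1 && kernel_h == 1 && !is_quantized;
}

// When im2col is skipped, the bias can no longer be appended to the im2col
// output, so it is added afterwards (ConvertPolicy::SATURATE in the patch)
bool run_add_bias_kernel(bool has_bias, bool is_quantized, bool skipped_im2col)
{
    return has_bias && !is_quantized && skipped_im2col;
}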
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 842ee73397..c2e18a760a 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -205,16 +205,17 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
const int k = a->dimension(0);
constexpr int mult_transpose1xW_width = 1;
constexpr int mult_interleave4x4_height = 1;
- const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
+ const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
+ const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d);
bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
if(reshape_matrices)
{
- TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height), 1, a->data_type());
+ TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()), 1, a->data_type());
TensorInfo info_b(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width), 1, b->data_type());
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &info_b, mult_transpose1xW_width));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output, reshape_matrices, reshape_info));
}
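
The same flag reaches the quantized path: compute_interleaved_shape() now receives reinterpret_input_as_3d so that info_a describes the collapsed matrix. A sketch of the shape arithmetic for the 4x4 interleave with mult_interleave4x4_height = 1; this is an illustration of the idea, not the library's compute_interleaved_shape():

#include <utility>

// Assumed 4x4 interleave block, mult_interleave4x4_height = 1; sketch only
std::pair<unsigned int, unsigned int> interleaved_2d_shape(unsigned int dim0, unsigned int dim1,
                                                           unsigned int dim2, bool input_is_3d)
{
    const unsigned int rows  = input_is_3d ? dim1 * dim2 : dim1; // collapse first
    const unsigned int out_w = dim0 * 4;                         // 4 rows packed per output row
    const unsigned int out_h = (rows + 3) / 4;                   // ceil(rows / 4)
    return { out_w, out_h };
}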
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index bc9a056e7c..9c45cb0657 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -48,18 +48,34 @@ public:
add_config(TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 0.4f, 0.7f);
}
};
-class LargeGEMM3DDataset final : public GEMMDataset
+
+class LargeGEMMOutput3DDataset final : public GEMMDataset
{
public:
- LargeGEMM3DDataset()
+ LargeGEMMOutput3DDataset()
{
add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
- add_config(TensorShape(364, 3025), TensorShape(96, 364), TensorShape(96, 605, 5), TensorShape(96, 605, 5), 1.0f, 0.0f);
- add_config(TensorShape(1201, 729), TensorShape(128, 1201), TensorShape(128, 243, 3), TensorShape(128, 243, 3), 1.0f, 0.0f);
- add_config(TensorShape(2305, 169), TensorShape(384, 2305), TensorShape(384, 13, 13), TensorShape(384, 13, 13), 1.0f, 0.0f);
- add_config(TensorShape(1729, 170), TensorShape(192, 1729), TensorShape(192, 85, 2), TensorShape(192, 85, 2), 1.0f, 0.0f);
- add_config(TensorShape(1729, 170), TensorShape(128, 1729), TensorShape(128, 17, 10), TensorShape(128, 17, 10), 1.0f, 0.0f);
+ add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
+ add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+ add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
+ add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f);
+ }
+};
+
+class LargeGEMMInputOutput3DDataset final : public GEMMDataset
+{
+public:
+ LargeGEMMInputOutput3DDataset()
+ {
+ add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+ add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+ add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
+ add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+ add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
+ add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U, 2U), TensorShape(192U, 85U, 2U, 2U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U, 3U), TensorShape(128U, 17U, 10U, 3U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.0f);
}
};
} // namespace datasets
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index c9bf674ad3..e108fcc1ca 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -49,10 +49,10 @@ public:
add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U), TensorShape(17U, 1U), 0.4f, 0.7f);
}
};
-class SmallGEMM3DDataset final : public GEMMDataset
+class SmallGEMMOutput3DDataset final : public GEMMDataset
{
public:
- SmallGEMM3DDataset()
+ SmallGEMMOutput3DDataset()
{
add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U, 7U, 2U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U, 1U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
@@ -62,6 +62,20 @@ public:
add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U, 8U, 2U, 5U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
}
};
+
+class SmallGEMMInputOutput3DDataset final : public GEMMDataset
+{
+public:
+ SmallGEMMInputOutput3DDataset()
+ {
+ add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U, 14U, 13U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f);
+ add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U, 1U, 3U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f);
+ add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U, 12U, 2U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f);
+ add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U, 1U, 4U, 3U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f);
+ add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U, 16U, 3U, 2U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U, 3U), TensorShape(8U, 16U, 5U, 3U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.0f);
+ }
+};
} // namespace datasets
} // namespace test
} // namespace arm_compute
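
Each InputOutput3D entry keeps the classic GEMM invariants once A is collapsed: k comes from dim0 of A and dim1 of B, while m is the product of the two trailing dimensions of A and must match the 3D output. Checking the first entry above by hand:

#include <cassert>

int main()
{
    // First SmallGEMMInputOutput3DDataset entry, shapes as (dim0, dim1, dim2):
    // A = (21, 14, 13), B = (34, 21), output = (34, 14, 13)
    const unsigned int a_dim0 = 21U, a_dim1 = 14U, a_dim2 = 13U;
    const unsigned int b_dim0 = 34U, b_dim1 = 21U;

    const unsigned int k = a_dim0;          // 21
    const unsigned int m = a_dim1 * a_dim2; // collapsed rows of A: 182
    const unsigned int n = b_dim0;          // 34

    assert(k == b_dim1);    // inner dimensions match
    assert(m == 14U * 13U); // output (34, 14, 13) flattens to (n, m) = (34, 182)
    (void)k; (void)m; (void)n;
    return 0;
}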
diff --git a/tests/validation/CL/GEMM.cpp b/tests/validation/CL/GEMM.cpp
index 639182030e..ff2071a756 100644
--- a/tests/validation/CL/GEMM.cpp
+++ b/tests/validation/CL/GEMM.cpp
@@ -107,6 +107,12 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::da
template <typename T>
using CLGEMMFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T>;
+template <typename T>
+using CLGEMMOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, true>;
+
+template <typename T>
+using CLGEMMInputOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, true, true>;
+
TEST_SUITE(TRANSPOSE_1XW)
using CLGEMMTranspose1xW = CLSynthetizeFunctionWithZeroConstantBorder<CLGEMMTranspose1xWKernel, 4>;
using CLGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<CLTensor, CLAccessor, CLGEMMTranspose1xW, float>;
@@ -149,17 +155,53 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::N
TEST_SUITE_END()
TEST_SUITE_END()
+TEST_SUITE(INPUT_OUTPUT_3D)
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMInputOutput3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMInputOutput3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMInputOutput3DDataset(),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMInputOutput3DDataset(),
+ framework::dataset::make("DataType",
+ DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+TEST_SUITE_END() // FP16
+
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // INPUT_OUTPUT_3D
+
TEST_SUITE(OUTPUT_3D)
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMM3DDataset(),
- framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMOutput3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMM3DDataset(),
- framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMOutput3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
@@ -167,15 +209,15 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::N
TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMM3DDataset(),
- framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMOutput3DDataset(),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMM3DDataset(),
- framework::dataset::make("DataType",
- DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMOutput3DDataset(),
+ framework::dataset::make("DataType",
+ DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index e4762cc5be..255b12c0ed 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -42,8 +42,8 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class GEMMValidationFixedPointFixture : public framework::Fixture
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false>
+class GEMMValidationFixture : public framework::Fixture
{
public:
template <typename...>
@@ -87,10 +87,7 @@ protected:
// The GEMMInfo includes the depth value in case the output is reinterpreted as 3D.
// If the output shape has the same number of dimensions as the input, a plain 2D matrix multiplication is run (depth_output_reinterpreted_as_3D = 1);
// in the other case the reinterpreted version of GEMM is used (depth_output_reinterpreted_as_3D = depth of the 3D output).
- bool is_output_reinterpreted_as_3D = output_shape.num_dimensions() > shape_a.num_dimensions();
- gemm.configure(&a, &b, &c, &dst, alpha, beta,
- GEMMInfo(false, false, false, is_output_reinterpreted_as_3D ? output_shape[2] : 1));
-
+ gemm.configure(&a, &b, &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 1), reinterpret_input_as_3d));
ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -121,8 +118,15 @@ protected:
SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &output_shape, float alpha, float beta,
DataType data_type)
{
+ TensorShape shape_a_to_use = shape_a;
+ if(reinterpret_input_as_3d)
+ {
+ // Collapse the second and third dimensions if the input is 3D
+ shape_a_to_use.collapse(2U, 1U);
+ }
+
// Create reference
- SimpleTensor<T> a{ shape_a, data_type, 1 };
+ SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
SimpleTensor<T> b{ shape_b, data_type, 1 };
SimpleTensor<T> c{ shape_c, data_type, 1 };
@@ -139,16 +143,6 @@ protected:
DataType _data_type{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class GEMMValidationFixture : public GEMMValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>
-{
-public:
- template <typename...>
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, DataType data_type)
- {
- GEMMValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type);
- }
-};
} // namespace validation
} // namespace test
} // namespace arm_compute
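
The reference path relies on TensorShape::collapse(2U, 1U) folding two dimensions, starting at index 1, into one, so the SimpleTensor reference computes a plain 2D GEMM on the flattened A. A sketch of that collapse on a bare vector of dimensions; collapse_shape is a hypothetical stand-in for the library call:

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for TensorShape::collapse(n, first): fold 'n'
// consecutive dimensions, starting at index 'first', into their product
std::vector<unsigned int> collapse_shape(std::vector<unsigned int> dims, size_t n, size_t first)
{
    if(n == 0 || first >= dims.size())
    {
        return dims;
    }
    const size_t last   = std::min(first + n, dims.size());
    unsigned int merged = 1;
    for(size_t i = first; i < last; ++i)
    {
        merged *= dims[i];
    }
    dims.erase(dims.begin() + first + 1, dims.begin() + last);
    dims[first] = merged;
    return dims;
}

int main()
{
    // shape_a = (21, 14, 13) collapses to (21, 182), matching compute_reference()
    for(unsigned int d : collapse_shape({ 21U, 14U, 13U }, 2, 1))
    {
        printf("%u ", d);
    }
    printf("\n");
    return 0;
}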