-rw-r--r--  arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h    |   7
-rw-r--r--  arm_compute/core/Types.h                                     |  17
-rw-r--r--  src/core/CL/CLKernelLibrary.cpp                              |   1
-rw-r--r--  src/core/CL/cl_kernels/gemm.cl                               | 348
-rw-r--r--  src/core/CL/cl_kernels/winograd_output_transform.cl          | 366
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp           |  19
-rw-r--r--  src/runtime/CL/functions/CLGEMM.cpp                          |   5
-rw-r--r--  src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp      |  13
-rw-r--r--  tests/SimpleTensor.h                                         |  39
-rw-r--r--  tests/validation/CL/Winograd.cpp                             |  21
-rw-r--r--  tests/validation/NEON/ConvolutionLayer.cpp                   |   2
-rw-r--r--  tests/validation/fixtures/WinogradConvolutionLayerFixture.h  |  41
-rw-r--r--  utils/Utils.h                                                |  52
13 files changed, 704 insertions(+), 227 deletions(-)
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
index e030fa2d2a..f61c330de6 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h
@@ -59,9 +59,11 @@ public:
* @param[in] alpha Weight of the matrix product
* @param[in] is_interleaved_transposed (Optional) True if input0 and input1 have been reshaped respectively using @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel
* @param[in] reshape_info (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
+ * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32-bit instead of 16-bit for FP16) to improve accuracy
*
*/
- void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo());
+ void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo(),
+ bool fp_mixed_precision = false);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyKernel
*
* @param[in] input0 Input tensor containing the Matrix A. Data types supported: F16/F32
@@ -71,11 +73,12 @@ public:
* @param[in] is_interleaved_transposed True if input0 and input1 have been reshaped respectively using @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel
* @param[in] reshape_info GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped
* @param[in] gpu_target GPU Target
+ * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32-bit instead of 16-bit for FP16) to improve accuracy
*
* @return a status
*/
static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
- GPUTarget gpu_target);
+ GPUTarget gpu_target, bool fp_mixed_precision = false);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
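Note: a minimal host-side sketch of exercising the new flag; the tensor setup, alpha value, and GPU target below are illustrative assumptions, not part of this patch.

    #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
    #include "arm_compute/core/GPUTarget.h"

    using namespace arm_compute;

    // Validate first, then configure; both entry points now take the trailing
    // fp_mixed_precision flag introduced by this patch.
    Status configure_f16_gemm_acc32(const ICLTensor *a, const ICLTensor *b, ICLTensor *out,
                                    CLGEMMMatrixMultiplyKernel &kernel)
    {
        Status s = CLGEMMMatrixMultiplyKernel::validate(a->info(), b->info(), out->info(),
                                                        1.0f /* alpha */,
                                                        false /* is_interleaved_transposed */,
                                                        GEMMReshapeInfo(),
                                                        GPUTarget::G71 /* illustrative target */,
                                                        true /* fp_mixed_precision */);
        if(bool(s))
        {
            kernel.configure(a, b, out, 1.0f, false, GEMMReshapeInfo(), true /* fp_mixed_precision */);
        }
        return s;
    }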
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index fb277584fd..4eb8129b62 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1593,7 +1593,8 @@ class GEMMInfo
public:
/** Default constructor */
GEMMInfo()
- : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage()
+ : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false),
+ _gemmlowp_output_stage(), _fp_mixed_precision(false)
{
}
/** Constructor
@@ -1607,12 +1608,13 @@ public:
* to perform 1x1 convolutions with the NHWC data layout)
* @param[in] retain_internal_weights (Optional) Retain the weights tensor from previous run
* @param[in] gemmlowp_output_stage (Optional) GEMMLowp Output stage info
+ * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32-bit instead of 16-bit for FP16) to improve accuracy.
*
*/
GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
- GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo())
+ GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false)
: _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d),
- _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage)
+ _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -1673,6 +1675,14 @@ public:
{
return _gemmlowp_output_stage;
};
+ /** Flag which specifies if a wider accumulator should be used.
+ *
+ * @return True if a wider accumulator has to be used
+ */
+ bool fp_mixed_precision() const
+ {
+ return _fp_mixed_precision;
+ };
private:
const bool _is_a_reshaped;
@@ -1682,6 +1692,7 @@ private:
const bool _reinterpret_input_as_3d;
const bool _retain_internal_weights;
const GEMMLowpOutputStageInfo _gemmlowp_output_stage;
+ const bool _fp_mixed_precision;
};
/** Winograd information */
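Note: with the flag threaded through GEMMInfo, callers can request the wider accumulators as sketched below; every value other than fp_mixed_precision keeps the defaults shown in this patch.

    #include <cassert>

    #include "arm_compute/core/Types.h"

    using namespace arm_compute;

    GEMMInfo make_f16_acc32_gemm_info()
    {
        GEMMInfo info(false /* is_a_reshaped */,
                      false /* is_b_reshaped */,
                      true  /* reshape_b_only_on_first_run */,
                      0     /* depth_output_gemm3d */,
                      false /* reinterpret_input_as_3d */,
                      false /* retain_internal_weights */,
                      GEMMLowpOutputStageInfo(),
                      true  /* fp_mixed_precision */);
        assert(info.fp_mixed_precision()); // read back through the new getter
        return info;
    }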
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index fde9608949..955844da3e 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -256,6 +256,7 @@ const std::map<std::string, std::string> CLKernelLibrary::_kernel_program_map =
{ "gemm_mm_interleaved_transposed_f32_bifrost", "gemm.cl" },
{ "gemm_mm_floating_point", "gemm.cl" },
{ "gemm_mm_floating_point_f16_bifrost", "gemm.cl" },
+ { "gemm_mm_floating_point_f16_bifrost_acc32", "gemm.cl" },
{ "gemm_mm_floating_point_f32_bifrost", "gemm.cl" },
{ "gemm_mm_floating_point_f32_bifrost_1000", "gemm.cl" },
{ "gemm_lc_vm_f32", "gemm.cl" },
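Note: the new map entry only ties the kernel name to gemm.cl; the sketch below shows a plausible way the name is assembled on the host. It mirrors the naming convention of the entries above and is an assumption, since the actual branching in configure() is outside this hunk.

    #include <string>

    #include "arm_compute/core/Types.h"

    std::string select_gemm_mm_kernel(arm_compute::DataType dt, bool is_bifrost, bool fp_mixed_precision)
    {
        std::string name = "gemm_mm_floating_point";
        if(is_bifrost)
        {
            name += (dt == arm_compute::DataType::F16) ? "_f16_bifrost" : "_f32_bifrost";
            if(dt == arm_compute::DataType::F16 && fp_mixed_precision)
            {
                name += "_acc32"; // resolved to "gemm.cl" via _kernel_program_map
            }
        }
        return name;
    }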
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index d24f014f11..5d5cab6578 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2299,6 +2299,354 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
*
+ * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
+ * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
+ * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
+ * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
+ * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
+ *
+ * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
+ * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
+ * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ */
+__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
+ IMAGE_DECLARATION(src1),
+ IMAGE_DECLARATION(dst),
+ uint src0_stride_z,
+ uint src1_stride_z,
+ uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+ ,
+ uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ ,
+ uint dst_cross_plane_pad
+#endif // REINTERPRET_OUTPUT_AS_3D
+ )
+{
+ int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
+
+ // Compute starting address for matrix A and Matrix B
+ int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
+
+ // Update address for the matrix A
+ src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
+
+ // Update address for the matrix B
+ src_addr.s1 += idx * sizeof(half);
+
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zin = min(DEPTH_GEMM3D - 1, zin);
+
+ // Add offset due to the cross plane paddings
+ zin *= (src_cross_plane_pad * src0_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply src0_stride_z by DEPTH_GEMM3D
+ src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
+ // Add offset for batched GEMM
+ src_addr.s0 += get_global_id(2) * src0_stride_z;
+
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
+ src_addr.s1 += get_global_id(2) * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+ float8 acc0 = 0.0f;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ float8 acc1 = 0.0f;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ float8 acc2 = 0.0f;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ float8 acc3 = 0.0f;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ int i = 0;
+ for(; i <= ((int)COLS_A - 4); i += 4)
+ {
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
+ // Load values from matrix B
+ float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+ src_addr.s1 += src1_stride_y;
+
+ // Accumulate
+ acc0 = fma(b0, (float8)a0.s0, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ acc1 = fma(b0, (float8)a1.s0, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ acc2 = fma(b0, (float8)a2.s0, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ acc3 = fma(b0, (float8)a3.s0, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+ src_addr.s1 += src1_stride_y;
+ acc0 = fma(b0, (float8)a0.s1, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ acc1 = fma(b0, (float8)a1.s1, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ acc2 = fma(b0, (float8)a2.s1, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ acc3 = fma(b0, (float8)a3.s1, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+ src_addr.s1 += src1_stride_y;
+ acc0 = fma(b0, (float8)a0.s2, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ acc1 = fma(b0, (float8)a1.s2, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ acc2 = fma(b0, (float8)a2.s2, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ acc3 = fma(b0, (float8)a3.s2, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+ src_addr.s1 += src1_stride_y;
+ acc0 = fma(b0, (float8)a0.s3, acc0);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ acc1 = fma(b0, (float8)a1.s3, acc1);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ acc2 = fma(b0, (float8)a2.s3, acc2);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ acc3 = fma(b0, (float8)a3.s3, acc3);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ src_addr.s0 += 4 * sizeof(half);
+ }
+
+ for(; i < (int)COLS_A; ++i)
+ {
+#if defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else // defined(REINTERPRET_INPUT_AS_3D)
+ // Load values from matrix A
+ half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
+ // Load values from matrix B
+ float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
+
+ src_addr += (int2)(sizeof(half), src1_stride_y);
+
+ // Accumulate
+ acc0 = fma(b0, (float8)a0, acc0); // b0 * (float8)a0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ acc1 = fma(b0, (float8)a1, acc1); // b0 * (float8)a1;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ acc2 = fma(b0, (float8)a2, acc2); // b0 * (float8)a2;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ acc3 = fma(b0, (float8)a3, acc3); // b0 * (float8)a3;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ }
+
+ // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+ half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
+#else //defined(ALPHA)
+ half8 hacc0 = convert_half8(acc0);
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if defined(ALPHA)
+ half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
+#else //defined(ALPHA)
+ half8 hacc1 = convert_half8(acc1);
+#endif //defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if defined(ALPHA)
+ half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
+#else //defined(ALPHA)
+ half8 hacc2 = convert_half8(acc2);
+#endif //defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if defined(ALPHA)
+ half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
+#else //defined(ALPHA)
+ half8 hacc3 = convert_half8(acc3);
+#endif // defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+ int z = get_global_id(2);
+
+ // Compute destination address
+ Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+ // Compute dst address
+ __global uchar *dst_addr = offset(&dst, 0, 0);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+ // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+ // in order to take into account the presence of possible cross plane paddings
+ //
+ // | |
+ // | plane0 |
+ // | |
+ // |__________________|
+ // |******************|
+ // | cross_plane_pad |
+ // |******************|
+ // | |
+ // | plane1 |
+ // | |
+ // |__________________|
+
+ // The plane (zout) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+ uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+ zout = min(DEPTH_GEMM3D - 1, zout);
+
+ // Add offset due to the cross plane paddings
+ zout *= (dst_cross_plane_pad * dst_stride_y);
+
+ // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+ // multiply dst_stride_z by DEPTH_GEMM3D
+ dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+ // Store the output block
+ vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+ // Add offset for batched GEMM
+ dst_addr += z * dst_stride_z;
+
+ // Store the output block
+ vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // REINTERPRET_OUTPUT_AS_3D
+}
+
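Note: the compile-time defines documented above are supplied when the program is built on the host. A hedged sketch using the library's CLBuildOptions/CLKernelLibrary helpers follows; the option values (COLS_A, tile sizes, ALPHA) are illustrative.

    #include "arm_compute/core/CL/CLHelpers.h"
    #include "arm_compute/core/CL/CLKernelLibrary.h"

    using namespace arm_compute;

    cl::Kernel build_acc32_kernel()
    {
        CLBuildOptions build_opts;
        build_opts.add_option("-DCOLS_A=24");                          // illustrative K
        build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=4"); // documented optimum
        build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=4");
        build_opts.add_option("-DALPHA=0.5f");                         // only when alpha != 1
        return static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_mm_floating_point_f16_bifrost_acc32", build_opts.options()));
    }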
+/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
+ *
* @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
* @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
* This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
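Note: the accuracy argument for the acc32 variant can be seen in a small standalone analogy. Standard C++ has no native half type, so float vs double stands in for half vs float below; the single narrowing at the end corresponds to the kernel's convert_half8(acc) before the store.

    #include <cstdio>

    int main()
    {
        const int   n    = 1 << 24;  // many small contributions, as in a long dot product
        const float term = 5e-8f;    // below half a ulp of the running sum

        float  acc_narrow = 1.0f;    // "half-like": accumulate in the narrow type
        double acc_wide   = 1.0;     // "float-like": accumulate in the wider type

        for(int i = 0; i < n; ++i)
        {
            acc_narrow += term;      // every add rounds away: the sum never moves
            acc_wide   += term;      // the wider accumulator keeps absorbing terms
        }

        // Prints roughly: narrow: 1.000000  wide(cast back): 1.838861
        std::printf("narrow: %f  wide(cast back): %f\n", acc_narrow, static_cast<float>(acc_wide));
        return 0;
    }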
diff --git a/src/core/CL/cl_kernels/winograd_output_transform.cl b/src/core/CL/cl_kernels/winograd_output_transform.cl
index 2c7c05fdd1..f52b027420 100644
--- a/src/core/CL/cl_kernels/winograd_output_transform.cl
+++ b/src/core/CL/cl_kernels/winograd_output_transform.cl
@@ -83,8 +83,8 @@ __kernel void winograd_output_transform_2x2_3x3_nchw(
// out00 = d00 + d01 + d02
// out01 = d01 - d02 - d03
- DATA_TYPE out00 = d00 + d01 + d02;
- DATA_TYPE out01 = d01 - d02 - d03;
+ float out00 = d00 + d01 + d02;
+ float out01 = d01 - d02 - d03;
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
@@ -102,20 +102,20 @@ __kernel void winograd_output_transform_2x2_3x3_nchw(
DATA_TYPE d33 = *((__global DATA_TYPE *)(src_addr + 15 * src_stride_z));
// Compute the 2x2 output tile
- DATA_TYPE k0 = d01 + d11 + d21;
- DATA_TYPE k1 = d02 + d12 + d22;
- DATA_TYPE k2 = d11 - d21 - d31;
- DATA_TYPE k3 = d12 - d22 - d32;
+ float k0 = d01 + d11 + d21;
+ float k1 = d02 + d12 + d22;
+ float k2 = d11 - d21 - d31;
+ float k3 = d12 - d22 - d32;
// out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
// out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
// out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
// out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)
- DATA_TYPE out00 = d10;
- DATA_TYPE out01 = -d13;
- DATA_TYPE out10 = d10;
- DATA_TYPE out11 = -d13;
+ float out00 = d10;
+ float out01 = -d13;
+ float out10 = d10;
+ float out11 = -d13;
out00 += d00 + d20 + k0 + k1;
out01 += k0 - k1 - (d03 + d23);
@@ -135,10 +135,10 @@ __kernel void winograd_output_transform_2x2_3x3_nchw(
// Add bias
Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
- DATA_TYPE b = (DATA_TYPE) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
+ float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
- out00 += (DATA_TYPE)b;
- out01 += (DATA_TYPE)b;
+ out00 += (float)b;
+ out01 += (float)b;
#endif // defined(HAS_BIAS)
// Get output address
@@ -150,8 +150,8 @@ __kernel void winograd_output_transform_2x2_3x3_nchw(
// Store the output tile
#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
- *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out00;
- *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out01;
+ *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = (DATA_TYPE)out00;
+ *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = (DATA_TYPE)out01;
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
vstore2((VEC_DATA_TYPE(DATA_TYPE, 2))(out00, out01), 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
@@ -163,7 +163,7 @@ __kernel void winograd_output_transform_2x2_3x3_nchw(
out11 += (DATA_TYPE)b;
#endif // defined(HAS_BIAS)
- vstore2((VEC_DATA_TYPE(DATA_TYPE, 2))(out10, out11), 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
+ vstore2((VEC_DATA_TYPE(DATA_TYPE, 2))((DATA_TYPE)out10, (DATA_TYPE)out11), 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
}
@@ -225,10 +225,10 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
// Compute out00, out01, out02 and out03
- DATA_TYPE out00 = d00 + d01 + d02 + d03 + d04;
- DATA_TYPE out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
- DATA_TYPE out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
- DATA_TYPE out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
+ float out00 = d00 + d01 + d02 + d03 + d04;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
DATA_TYPE d11 = *((__global DATA_TYPE *)(src_addr + 7 * src_stride_z));
@@ -266,13 +266,13 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
// Compute out00, out01, out02 and out03
- DATA_TYPE out00 = d01 + d21 + d41 + d11 + d31;
- DATA_TYPE out01 = d01 + d21 + d41 + d11 + d31;
- DATA_TYPE out02 = d01 + d21 + d41 + d11 + d31;
- DATA_TYPE out03 = d01 + d21 + d41 + d11 + d31;
+ float out00 = (float)d01 + (float)d21 + (float)d41 + (float)d11 + (float)d31;
+ float out01 = (float)d01 + (float)d21 + (float)d41 + (float)d11 + (float)d31;
+ float out02 = (float)d01 + (float)d21 + (float)d41 + (float)d11 + (float)d31;
+ float out03 = (float)d01 + (float)d21 + (float)d41 + (float)d11 + (float)d31;
- DATA_TYPE k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
- DATA_TYPE k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
+ float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
+ float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
out01 += k1 - d02 - d12 - d22 - d32 - d42;
@@ -280,10 +280,10 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
// Compute out10, out11, out12 and out13
- DATA_TYPE out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- DATA_TYPE out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- DATA_TYPE out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- DATA_TYPE out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
@@ -294,10 +294,10 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
// Compute out20, out21, out22 and out23
- DATA_TYPE out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- DATA_TYPE out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- DATA_TYPE out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- DATA_TYPE out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
@@ -308,10 +308,10 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
// Compute out30, out31, out32 and out33
- DATA_TYPE out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- DATA_TYPE out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- DATA_TYPE out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- DATA_TYPE out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
@@ -334,12 +334,12 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
// Add bias
Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
- DATA_TYPE b = (DATA_TYPE) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
+ float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
- out00 += (DATA_TYPE)b;
- out01 += (DATA_TYPE)b;
- out02 += (DATA_TYPE)b;
- out03 += (DATA_TYPE)b;
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
#endif // defined(HAS_BIAS)
// Get output address
@@ -351,35 +351,35 @@ __kernel void winograd_output_transform_4x4_3x3_nchw(
// Store the output tile
#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
- *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = out00;
- *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = out01;
- *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = out02;
- *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = out03;
+ *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)) = (DATA_TYPE)out00;
+ *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)) = (DATA_TYPE)out01;
+ *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)) = (DATA_TYPE)out02;
+ *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)) = (DATA_TYPE)out03;
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out00, out01, out02, out03), 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out00, (DATA_TYPE)out01, (DATA_TYPE)out02, (DATA_TYPE)out03), 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
#if defined(HAS_BIAS)
// Add bias
- out10 += (DATA_TYPE)b;
- out11 += (DATA_TYPE)b;
- out12 += (DATA_TYPE)b;
- out13 += (DATA_TYPE)b;
-
- out20 += (DATA_TYPE)b;
- out21 += (DATA_TYPE)b;
- out22 += (DATA_TYPE)b;
- out23 += (DATA_TYPE)b;
-
- out30 += (DATA_TYPE)b;
- out31 += (DATA_TYPE)b;
- out32 += (DATA_TYPE)b;
- out33 += (DATA_TYPE)b;
+ out10 += (float)b;
+ out11 += (float)b;
+ out12 += (float)b;
+ out13 += (float)b;
+
+ out20 += (float)b;
+ out21 += (float)b;
+ out22 += (float)b;
+ out23 += (float)b;
+
+ out30 += (float)b;
+ out31 += (float)b;
+ out32 += (float)b;
+ out33 += (float)b;
#endif // defined(HAS_BIAS)
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out10, out11, out12, out13), 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out20, out21, out22, out23), 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out30, out31, out32, out33), 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out10, (DATA_TYPE)out11, (DATA_TYPE)out12, (DATA_TYPE)out13), 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out20, (DATA_TYPE)out21, (DATA_TYPE)out22, (DATA_TYPE)out23), 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out30, (DATA_TYPE)out31, (DATA_TYPE)out32, (DATA_TYPE)out33), 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
}
@@ -441,10 +441,10 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
// Compute out00, out01, out02 and out03
- DATA_TYPE out00 = d00 + d01 + d02 + d03 + d04;
- DATA_TYPE out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
- DATA_TYPE out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
- DATA_TYPE out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
+ float out00 = d00 + d01 + d02 + d03 + d04;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
DATA_TYPE d10 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
@@ -483,13 +483,13 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
DATA_TYPE d55 = *((__global DATA_TYPE *)(src_addr + 35 * src_stride_z));
// Compute out00, out01, out02 and out03
- DATA_TYPE out00 = d01 + d21 + d41 + d11 + d31;
- DATA_TYPE out01 = d01 + d21 + d41 + d11 + d31;
- DATA_TYPE out02 = d01 + d21 + d41 + d11 + d31;
- DATA_TYPE out03 = d01 + d21 + d41 + d11 + d31;
+ float out00 = d01 + d21 + d41 + d11 + d31;
+ float out01 = d01 + d21 + d41 + d11 + d31;
+ float out02 = d01 + d21 + d41 + d11 + d31;
+ float out03 = d01 + d21 + d41 + d11 + d31;
- DATA_TYPE k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
- DATA_TYPE k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
+ float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
+ float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
out01 += k1 - d02 - d12 - d22 - d32 - d42;
@@ -497,10 +497,10 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
// Compute out10, out11, out12 and out13
- DATA_TYPE out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- DATA_TYPE out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- DATA_TYPE out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
- DATA_TYPE out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
@@ -511,10 +511,10 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
// Compute out20, out21, out22 and out23
- DATA_TYPE out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- DATA_TYPE out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- DATA_TYPE out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
- DATA_TYPE out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
@@ -525,10 +525,10 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
// Compute out30, out31, out32 and out33
- DATA_TYPE out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- DATA_TYPE out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- DATA_TYPE out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
- DATA_TYPE out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
@@ -585,19 +585,19 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, (int4)dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
// Store the 1x4 output tile
- *((__global DATA_TYPE *)(dst_ptr + offset.s0)) = out00;
- *((__global DATA_TYPE *)(dst_ptr + offset.s1)) = out01;
- *((__global DATA_TYPE *)(dst_ptr + offset.s2)) = out02;
- *((__global DATA_TYPE *)(dst_ptr + offset.s3)) = out03;
+ *((__global DATA_TYPE *)(dst_ptr + offset.s0)) = (DATA_TYPE)out00;
+ *((__global DATA_TYPE *)(dst_ptr + offset.s1)) = (DATA_TYPE)out01;
+ *((__global DATA_TYPE *)(dst_ptr + offset.s2)) = (DATA_TYPE)out02;
+ *((__global DATA_TYPE *)(dst_ptr + offset.s3)) = (DATA_TYPE)out03;
#elif defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
// Store the 4x1 output tile
int offset = dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
int mult_y = min(dst_size - offset, 1);
- *((__global DATA_TYPE *)(dst_ptr + mult_y * 0 * dst_stride_y + offset)) = out00;
- *((__global DATA_TYPE *)(dst_ptr + mult_y * 1 * dst_stride_y + offset)) = out01;
- *((__global DATA_TYPE *)(dst_ptr + mult_y * 2 * dst_stride_y + offset)) = out02;
- *((__global DATA_TYPE *)(dst_ptr + mult_y * 3 * dst_stride_y + offset)) = out03;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y * 0 * dst_stride_y + offset)) = (DATA_TYPE)out00;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y * 1 * dst_stride_y + offset)) = (DATA_TYPE)out01;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y * 2 * dst_stride_y + offset)) = (DATA_TYPE)out02;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y * 3 * dst_stride_y + offset)) = (DATA_TYPE)out03;
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
// Get output address
#if defined(SRC_DEPTH)
@@ -609,22 +609,22 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
int4 mult_y = min((int4)dst_size - offset, (int4)1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.
// Store the 4x4 output tile
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0)) = out00;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0)) = out01;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0)) = out02;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0)) = out03;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1)) = out10;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1)) = out11;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1)) = out12;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1)) = out13;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2)) = out20;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2)) = out21;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2)) = out22;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2)) = out23;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3)) = out30;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3)) = out31;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3)) = out32;
- *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3)) = out33;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0)) = (DATA_TYPE)out00;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0)) = (DATA_TYPE)out01;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0)) = (DATA_TYPE)out02;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0)) = (DATA_TYPE)out03;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1)) = (DATA_TYPE)out10;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1)) = (DATA_TYPE)out11;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1)) = (DATA_TYPE)out12;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1)) = (DATA_TYPE)out13;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2)) = (DATA_TYPE)out20;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2)) = (DATA_TYPE)out21;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2)) = (DATA_TYPE)out22;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2)) = (DATA_TYPE)out23;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3)) = (DATA_TYPE)out30;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3)) = (DATA_TYPE)out31;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3)) = (DATA_TYPE)out32;
+ *((__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3)) = (DATA_TYPE)out33;
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
}
@@ -721,16 +721,16 @@ __kernel void winograd_output_transform_4x4_5x5_nchw(
#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
// Compute out00, out01, out02 and out03
- DATA_TYPE out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
- DATA_TYPE out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
- DATA_TYPE out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
- DATA_TYPE out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
+ float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
#if defined(HAS_BIAS)
// Add bias
Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
- DATA_TYPE b = (DATA_TYPE) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
+ float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
out00 += (DATA_TYPE)b;
out01 += (DATA_TYPE)b;
@@ -813,9 +813,9 @@ __kernel void winograd_output_transform_4x4_5x5_nchw(
DATA_TYPE d77 = *((__global DATA_TYPE *)(src_addr + 63 * src_stride_z));
// Compute the 8x4 intermediate tensor
- VEC_DATA_TYPE(DATA_TYPE, 4)
+ VEC_DATA_TYPE(float, 4)
comm_fact0, comm_fact1, comm_fact2;
- VEC_DATA_TYPE(DATA_TYPE, 4)
+ VEC_DATA_TYPE(float, 4)
tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
@@ -832,37 +832,37 @@ __kernel void winograd_output_transform_4x4_5x5_nchw(
comm_fact1 = tmp_col3 + tmp_col4;
comm_fact2 = tmp_col5 + tmp_col6;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col0 = comm_fact0 + comm_fact1 + (DATA_TYPE)8.f * comm_fact2 + tmp_col0;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col2 = comm_fact0 + (DATA_TYPE)4.f * comm_fact1 + (DATA_TYPE)2.f * comm_fact2;
+ VEC_DATA_TYPE(float, 4)
+ out_col0 = comm_fact0 + comm_fact1 + (float)8.f * comm_fact2 + tmp_col0;
+ VEC_DATA_TYPE(float, 4)
+ out_col2 = comm_fact0 + (float)4.f * comm_fact1 + (float)2.f * comm_fact2;
comm_fact0 = tmp_col1 - tmp_col2;
comm_fact1 = tmp_col3 - tmp_col4;
comm_fact2 = tmp_col5 - tmp_col6;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col1 = comm_fact0 + (DATA_TYPE)2.f * comm_fact1 + (DATA_TYPE)4.f * comm_fact2;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col3 = comm_fact0 + (DATA_TYPE)8.f * comm_fact1 + comm_fact2 + tmp_col7;
+ VEC_DATA_TYPE(float, 4)
+ out_col1 = comm_fact0 + (float)2.f * comm_fact1 + (float)4.f * comm_fact2;
+ VEC_DATA_TYPE(float, 4)
+ out_col3 = comm_fact0 + (float)8.f * comm_fact1 + comm_fact2 + tmp_col7;
#if defined(HAS_BIAS)
// Add bias
Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
- DATA_TYPE b = (DATA_TYPE) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
+ float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, z_out)));
- out_col0 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
- out_col1 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
- out_col2 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
- out_col3 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
+ out_col0 += (VEC_DATA_TYPE(float, 4))b;
+ out_col1 += (VEC_DATA_TYPE(float, 4))b;
+ out_col2 += (VEC_DATA_TYPE(float, 4))b;
+ out_col3 += (VEC_DATA_TYPE(float, 4))b;
#endif // defined(HAS_BIAS)
// Store the output tile
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s0, out_col1.s0, out_col2.s0, out_col3.s0), 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s1, out_col1.s1, out_col2.s1, out_col3.s1), 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s2, out_col1.s2, out_col2.s2, out_col3.s2), 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
- vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(out_col0.s3, out_col1.s3, out_col2.s3, out_col3.s3), 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out_col0.s0, (DATA_TYPE)out_col1.s0, (DATA_TYPE)out_col2.s0, (DATA_TYPE)out_col3.s0), 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out_col0.s1, (DATA_TYPE)out_col1.s1, (DATA_TYPE)out_col2.s1, (DATA_TYPE)out_col3.s1), 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out_col0.s2, (DATA_TYPE)out_col1.s2, (DATA_TYPE)out_col2.s2, (DATA_TYPE)out_col3.s2), 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
+ vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))((DATA_TYPE)out_col0.s3, (DATA_TYPE)out_col1.s3, (DATA_TYPE)out_col2.s3, (DATA_TYPE)out_col3.s3), 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
}
@@ -933,21 +933,21 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
// Compute out00, out01, out02 and out03
- DATA_TYPE out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
- DATA_TYPE out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
- DATA_TYPE out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
- DATA_TYPE out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
+ float out00 = d00 + d01 + d02 + d03 + d04 + 8.0f * d05 + 8.0f * d06;
+ float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04 + 4.0f * d05 - 4.0f * d06;
+ float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04 + 2.0f * d05 + 2.0f * d06;
+ float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05 - d06 + d07;
#if defined(HAS_BIAS)
// Add bias
Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
- DATA_TYPE b = (DATA_TYPE) * ((__global DATA_TYPE *)(vector_offset(&bias, x_out)));
+ float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, x_out)));
- out00 += (DATA_TYPE)b;
- out01 += (DATA_TYPE)b;
- out02 += (DATA_TYPE)b;
- out03 += (DATA_TYPE)b;
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
#endif // defined(HAS_BIAS)
// Store the output tile
@@ -960,18 +960,18 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
#endif /* defined(SRC_DEPTH) */
offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, (int4)dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
- *(__global DATA_TYPE *)(dst_ptr + offset.s0) = out00;
- *(__global DATA_TYPE *)(dst_ptr + offset.s1) = out01;
- *(__global DATA_TYPE *)(dst_ptr + offset.s2) = out02;
- *(__global DATA_TYPE *)(dst_ptr + offset.s3) = out03;
+ *(__global DATA_TYPE *)(dst_ptr + offset.s0) = (DATA_TYPE)out00;
+ *(__global DATA_TYPE *)(dst_ptr + offset.s1) = (DATA_TYPE)out01;
+ *(__global DATA_TYPE *)(dst_ptr + offset.s2) = (DATA_TYPE)out02;
+ *(__global DATA_TYPE *)(dst_ptr + offset.s3) = (DATA_TYPE)out03;
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
// Get output address
int offset = dst_offset_first_element_in_bytes + x_out * sizeof(DATA_TYPE) + y_out * dst_stride_y + z_out * dst_stride_z;
- *(__global DATA_TYPE *)(dst_ptr + 0 * dst_stride_y + offset) = out00;
- *(__global DATA_TYPE *)(dst_ptr + 1 * dst_stride_y + offset) = out01;
- *(__global DATA_TYPE *)(dst_ptr + 2 * dst_stride_y + offset) = out02;
- *(__global DATA_TYPE *)(dst_ptr + 3 * dst_stride_y + offset) = out03;
+ *(__global DATA_TYPE *)(dst_ptr + 0 * dst_stride_y + offset) = (DATA_TYPE)out00;
+ *(__global DATA_TYPE *)(dst_ptr + 1 * dst_stride_y + offset) = (DATA_TYPE)out01;
+ *(__global DATA_TYPE *)(dst_ptr + 2 * dst_stride_y + offset) = (DATA_TYPE)out02;
+ *(__global DATA_TYPE *)(dst_ptr + 3 * dst_stride_y + offset) = (DATA_TYPE)out03;
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
@@ -1040,9 +1040,9 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
DATA_TYPE d77 = *((__global DATA_TYPE *)(src_addr + 63 * src_stride_z));
// Compute the 8x4 intermediate tensor
- VEC_DATA_TYPE(DATA_TYPE, 4)
+ VEC_DATA_TYPE(float, 4)
comm_fact0, comm_fact1, comm_fact2;
- VEC_DATA_TYPE(DATA_TYPE, 4)
+ VEC_DATA_TYPE(float, 4)
tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
@@ -1059,30 +1059,30 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
comm_fact1 = tmp_col3 + tmp_col4;
comm_fact2 = tmp_col5 + tmp_col6;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col0 = comm_fact0 + comm_fact1 + (DATA_TYPE)8.f * comm_fact2 + tmp_col0;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col2 = comm_fact0 + (DATA_TYPE)4.f * comm_fact1 + (DATA_TYPE)2.f * comm_fact2;
+ VEC_DATA_TYPE(float, 4)
+ out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
+ VEC_DATA_TYPE(float, 4)
+ out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
comm_fact0 = tmp_col1 - tmp_col2;
comm_fact1 = tmp_col3 - tmp_col4;
comm_fact2 = tmp_col5 - tmp_col6;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col1 = comm_fact0 + (DATA_TYPE)2.f * comm_fact1 + (DATA_TYPE)4.f * comm_fact2;
- VEC_DATA_TYPE(DATA_TYPE, 4)
- out_col3 = comm_fact0 + (DATA_TYPE)8.f * comm_fact1 + comm_fact2 + tmp_col7;
+ VEC_DATA_TYPE(float, 4)
+ out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
+ VEC_DATA_TYPE(float, 4)
+ out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
#if defined(HAS_BIAS)
// Add bias
Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
- DATA_TYPE b = (DATA_TYPE) * ((__global DATA_TYPE *)(vector_offset(&bias, x_out)));
+ float b = (float) * ((__global DATA_TYPE *)(vector_offset(&bias, x_out)));
- out_col0 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
- out_col1 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
- out_col2 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
- out_col3 += (VEC_DATA_TYPE(DATA_TYPE, 4))b;
+ out_col0 += (VEC_DATA_TYPE(float, 4))b;
+ out_col1 += (VEC_DATA_TYPE(float, 4))b;
+ out_col2 += (VEC_DATA_TYPE(float, 4))b;
+ out_col3 += (VEC_DATA_TYPE(float, 4))b;
#endif // defined(HAS_BIAS)
// Get output address
#if defined(SRC_DEPTH)
@@ -1094,22 +1094,22 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
int4 mult_y = min((int4)dst_size - offset, (int4)1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.
// Store the output tile
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 0 * (int)dst_stride_y + offset.s0) = out_col0.s0;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 1 * (int)dst_stride_y + offset.s0) = out_col1.s0;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 2 * (int)dst_stride_y + offset.s0) = out_col2.s0;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 3 * (int)dst_stride_y + offset.s0) = out_col3.s0;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 0 * (int)dst_stride_y + offset.s1) = out_col0.s1;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 1 * (int)dst_stride_y + offset.s1) = out_col1.s1;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 2 * (int)dst_stride_y + offset.s1) = out_col2.s1;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 3 * (int)dst_stride_y + offset.s1) = out_col3.s1;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 0 * (int)dst_stride_y + offset.s2) = out_col0.s2;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 1 * (int)dst_stride_y + offset.s2) = out_col1.s2;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 2 * (int)dst_stride_y + offset.s2) = out_col2.s2;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 3 * (int)dst_stride_y + offset.s2) = out_col3.s2;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 0 * (int)dst_stride_y + offset.s3) = out_col0.s3;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 1 * (int)dst_stride_y + offset.s3) = out_col1.s3;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 2 * (int)dst_stride_y + offset.s3) = out_col2.s3;
- *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 3 * (int)dst_stride_y + offset.s3) = out_col3.s3;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 0 * (int)dst_stride_y + offset.s0) = (DATA_TYPE)out_col0.s0;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 1 * (int)dst_stride_y + offset.s0) = (DATA_TYPE)out_col1.s0;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 2 * (int)dst_stride_y + offset.s0) = (DATA_TYPE)out_col2.s0;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s0 * 3 * (int)dst_stride_y + offset.s0) = (DATA_TYPE)out_col3.s0;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 0 * (int)dst_stride_y + offset.s1) = (DATA_TYPE)out_col0.s1;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 1 * (int)dst_stride_y + offset.s1) = (DATA_TYPE)out_col1.s1;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 2 * (int)dst_stride_y + offset.s1) = (DATA_TYPE)out_col2.s1;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s1 * 3 * (int)dst_stride_y + offset.s1) = (DATA_TYPE)out_col3.s1;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 0 * (int)dst_stride_y + offset.s2) = (DATA_TYPE)out_col0.s2;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 1 * (int)dst_stride_y + offset.s2) = (DATA_TYPE)out_col1.s2;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 2 * (int)dst_stride_y + offset.s2) = (DATA_TYPE)out_col2.s2;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s2 * 3 * (int)dst_stride_y + offset.s2) = (DATA_TYPE)out_col3.s2;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 0 * (int)dst_stride_y + offset.s3) = (DATA_TYPE)out_col0.s3;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 1 * (int)dst_stride_y + offset.s3) = (DATA_TYPE)out_col1.s3;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 2 * (int)dst_stride_y + offset.s3) = (DATA_TYPE)out_col2.s3;
+ *(__global DATA_TYPE *)(dst_ptr + mult_y.s3 * 3 * (int)dst_stride_y + offset.s3) = (DATA_TYPE)out_col3.s3;
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
}
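The output transform above now keeps every intermediate column in float and narrows to DATA_TYPE only on the final store. A minimal C++ sketch of that narrowing pattern, assuming half_float::half as the FP16 storage type (the include path and the helper name store_output_tile are assumptions, not part of the patch):

    #include "half/half.hpp" // assumed vendored path of the half_float library

    using fp16 = half_float::half;

    // Accumulate in FP32; round to FP16 exactly once per element, on store.
    void store_output_tile(const float (&out_col)[4], fp16 *dst)
    {
        for(int i = 0; i < 4; ++i)
        {
            dst[i] = static_cast<fp16>(out_col[i]);
        }
    }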
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 5e02dda9e3..b549638343 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -47,12 +47,14 @@ namespace
{
using ElementsProcessed = Steps;
-inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
+ bool fp_mixed_precision)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((fp_mixed_precision && (input0->data_type() != DataType::F16)), "Mixed precision floating point is supported only for F16 data");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
@@ -216,12 +218,13 @@ CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
{
}
-void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info,
+ bool fp_mixed_precision)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
// Perform validate step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, fp_mixed_precision));
_input0 = input0;
_input1 = input1;
@@ -316,6 +319,11 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// The work-group size equal to the Bifrost quad size has been proved to be optimal for these kernels
// via exhaustive autotuning over a range of representative layer configurations.
set_lws_hint(cl::NDRange(4));
+ if(fp_mixed_precision && data_type == DataType::F16)
+ {
+ // Currently the wider (32-bit) accumulator is only supported for the FP16 kernels.
+ kernel_name += "_acc32";
+ }
}
else // (MIDGARD and F32) or (F16)
{
@@ -331,6 +339,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
// Set config_id for enabling LWS tuning
_config_id = "gemm_";
_config_id += (is_interleaved_transposed ? "reshaped_" : "");
+ _config_id += (fp_mixed_precision ? "fp_mixed_" : "");
_config_id += (_reinterpret_input_as_3d ? "3di_" : "");
_config_id += (_reinterpret_output_as_3d ? "3do_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
@@ -347,12 +356,12 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
}
Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved_transposed,
- const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target)
+ const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision)
{
// Note: num_elements_processed will be set in validate_and_configure_window()
ElementsProcessed num_elements_processed{};
ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, is_interleaved_transposed, reshape_info, fp_mixed_precision));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(),
input1->clone().get(),
output->clone().get(),
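For reference, a hedged usage sketch of the widened configure()/validate() pair; the wrapper function, the tensor variables and the GPU target are illustrative only:

    // Hypothetical caller: request FP32 accumulation for an F16 GEMM.
    // validate_arguments() above rejects fp_mixed_precision for non-F16 data.
    void configure_mixed_precision_gemm(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output,
                                        CLGEMMMatrixMultiplyKernel &mm_kernel)
    {
        const Status s = CLGEMMMatrixMultiplyKernel::validate(input0->info(), input1->info(), output->info(),
                                                              1.0f /* alpha */, false /* is_interleaved_transposed */,
                                                              GEMMReshapeInfo(), GPUTarget::G71, true /* fp_mixed_precision */);
        if(s.error_code() == ErrorCode::OK)
        {
            mm_kernel.configure(input0, input1, output, 1.0f, false, GEMMReshapeInfo(), true /* fp_mixed_precision */);
        }
    }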
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 6adbdc0cb6..baa0cf46dc 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -155,7 +155,8 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *
// Configure and tune matrix multiply kernel
_mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
mult_transpose1xW_width, mult_interleave4x4_height,
- depth_output_gemm3d, reinterpret_input_as_3d));
+ depth_output_gemm3d, reinterpret_input_as_3d),
+ gemm_info.fp_mixed_precision());
CLScheduler::get().tune_kernel_static(_mm_kernel);
if(_is_interleaved_transposed)
@@ -236,7 +237,7 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
}
// Validate matrix multiply
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
if(beta != 0 && c != nullptr)
{
diff --git a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
index 70bf3ae593..1abcb67132 100644
--- a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
@@ -104,9 +104,9 @@ void CLWinogradConvolutionLayer::configure(ICLTensor *input, const ICLTensor *we
// Check if the Winograd configuration requires fast math
if(!enable_fast_math)
{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // Disable Winograd for FP16 when fast math is disabled.
ARM_COMPUTE_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size), "This Winograd configuration requires enable_fast_math=true");
}
-
const WinogradInfo winograd_info = WinogradInfo(output_tile,
kernel_size,
input_dims,
@@ -129,7 +129,8 @@ void CLWinogradConvolutionLayer::configure(ICLTensor *input, const ICLTensor *we
_filter_transform.configure(weights, &_input1, winograd_info);
// Configure batched matrix multiply
- _batched_mm.configure(&_input0, &_input1, nullptr, &_batched_mm_output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/));
+ _batched_mm.configure(&_input0, &_input1, nullptr, &_batched_mm_output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, 0, false, false, GEMMLowpOutputStageInfo(),
+ (input->info()->data_type() == DataType::F16)));
// Configure output transform
_output_transform.configure(&_batched_mm_output, biases, output, winograd_info);
@@ -158,13 +159,10 @@ Status CLWinogradConvolutionLayer::validate(const ITensorInfo *input, const ITen
const Size2D kernel_size = Size2D(weights->tensor_shape()[idx_width], weights->tensor_shape()[idx_height]);
const Size2D output_tile = winograd_output_tile(input_dims, kernel_size, input->data_layout());
- //FP16 implementation of winograd is slower than direct convolution.
- //The following check needs to be removed when fp16 winograd is faster than direct convolution (COMPMID-1266)
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
-
// Check if the Winograd configuration requires fast math
if(!enable_fast_math)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); // Disable Winograd for FP16 when fast math is disabled.
ARM_COMPUTE_RETURN_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size), "This Winograd configuration requires enable_fast_math=true");
}
@@ -188,7 +186,8 @@ Status CLWinogradConvolutionLayer::validate(const ITensorInfo *input, const ITen
TensorShape batched_mm_output_shape = input0.tensor_shape();
batched_mm_output_shape[0] = input1.tensor_shape()[0];
const TensorInfo batched_mm_output = input0.clone()->set_tensor_shape(batched_mm_output_shape);
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input0, &input1, nullptr, &batched_mm_output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/)));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input0, &input1, nullptr, &batched_mm_output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, 0, false, false,
+ GEMMLowpOutputStageInfo(), (input->data_type() == DataType::F16))));
// Configure output transform
ARM_COMPUTE_RETURN_ON_ERROR(CLWinogradOutputTransformKernel::validate(&batched_mm_output, biases, output, winograd_info));
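The two GEMMInfo constructions above are the only places the new flag is raised. Spelled out with positional labels (a sketch; the parameter names are inferred from GEMMInfo's member order and should be treated as assumptions):

    GEMMInfo make_winograd_gemm_info(const ITensorInfo *input) // hypothetical helper
    {
        const bool fp_mixed = (input->data_type() == DataType::F16);
        return GEMMInfo(false /* is_a_reshaped */, false /* is_b_reshaped */,
                        true /* reshape_b_only_on_first_run */, 0 /* depth_output_gemm3d */,
                        false /* reinterpret_input_as_3d */, false /* retain_internal_weights */,
                        GEMMLowpOutputStageInfo(), fp_mixed /* fp_mixed_precision */);
    }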
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index 335ef9130a..dd4a8bee2c 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -220,6 +220,45 @@ protected:
DataLayout _data_layout{ DataLayout::UNKNOWN };
};
+template <typename T1, typename T2>
+SimpleTensor<T1> copy_tensor(const SimpleTensor<T2> &tensor)
+{
+ SimpleTensor<T1> st(tensor.shape(), tensor.data_type(),
+ tensor.num_channels(),
+ tensor.quantization_info(),
+ tensor.data_layout());
+ for(size_t n = 0; n < size_t(st.num_elements()); n++)
+ {
+ st.data()[n] = static_cast<T1>(tensor.data()[n]);
+ }
+ return st;
+}
+
+template <typename T1, typename T2, typename std::enable_if<std::is_same<T1, T2>::value, int>::type = 0>
+SimpleTensor<T1> copy_tensor(const SimpleTensor<half> &tensor)
+{
+ SimpleTensor<T1> st(tensor.shape(), tensor.data_type(),
+ tensor.num_channels(),
+ tensor.quantization_info(),
+ tensor.data_layout());
+ memcpy((void *)st.data(), (const void *)tensor.data(), size_t(st.num_elements() * sizeof(T1)));
+ return st;
+}
+
+template < typename T1, typename T2, typename std::enable_if < (std::is_same<T1, half>::value || std::is_same<T2, half>::value), int >::type = 0 >
+SimpleTensor<T1> copy_tensor(const SimpleTensor<half> &tensor)
+{
+ SimpleTensor<T1> st(tensor.shape(), tensor.data_type(),
+ tensor.num_channels(),
+ tensor.quantization_info(),
+ tensor.data_layout());
+ for(size_t n = 0; n < size_t(st.num_elements()); n++)
+ {
+ st.data()[n] = half_float::detail::half_cast<T1, T2>(tensor.data()[n]);
+ }
+ return st;
+}
+
template <typename T>
SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format)
: _buffer(nullptr),
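A short usage sketch for the copy_tensor overloads above (the demo function is hypothetical). Overload resolution picks the most specialized viable candidate, so same-type half copies take the memcpy path, half-to-other conversions take the half_cast path, and all remaining combinations fall back to the element-wise static_cast loop:

    void copy_tensor_demo()
    {
        SimpleTensor<half>  src_f16{ TensorShape(8U, 8U), DataType::F16, 1 };
        SimpleTensor<float> as_f32 = copy_tensor<float, half>(src_f16); // half_cast per element
        SimpleTensor<half>  clone  = copy_tensor<half, half>(src_f16);  // raw memcpy
        // Note: as_f32.data_type() still reports F16; only the element storage widens,
        // which is exactly what the Winograd fixture below relies on.
    }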
diff --git a/tests/validation/CL/Winograd.cpp b/tests/validation/CL/Winograd.cpp
index 930f7aa8ce..f7f06b7f79 100644
--- a/tests/validation/CL/Winograd.cpp
+++ b/tests/validation/CL/Winograd.cpp
@@ -58,6 +58,9 @@ constexpr AbsoluteTolerance<float> tolerance_f32(0.001f);
const AbsoluteTolerance<half> tolerance_f16(half(0.5f));
constexpr AbsoluteTolerance<float> tolerance_convolution_layer_f32(0.1f);
const AbsoluteTolerance<half> tolerance_convolution_layer_f16(half(0.4f));
+RelativeTolerance<half_float::half> rel_tolerance_f16(half(0.2)); /**< Relative tolerance value for comparing reference's output against implementation's output for FP16 data types */
+constexpr float tolerance_num = 0.05f; /**< Maximum allowed ratio of mismatching elements */
+constexpr float abs_tolerance_convolution_layer_f16 = 2.5f; /**< Absolute tolerance value for comparing reference's output against implementation's output for FP16 data types */
// Input transform
const auto SmallWinogradInputTransformDatasetNCHW =
@@ -834,10 +837,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture, fram
TEST_SUITE_END() // Conv1x5
TEST_SUITE_END() // FP32
-#ifdef WINOGRAD_F16_SUPPORT //to be reintroduced after COMPMID-1266 is resolved
+
TEST_SUITE(FP16)
-using CLWinogradConvolutionLayerFastMathFixture16 = WinogradConvolutionLayerFastMathValidationFixture<CLTensor, CLAccessor, CLWinogradConvolutionLayer, half>;
+using CLWinogradConvolutionLayerFastMathFixture16 = WinogradConvolutionLayerFastMathValidationFixture<CLTensor, CLAccessor, CLWinogradConvolutionLayer, half, float>;
TEST_SUITE(Conv3x3)
FIXTURE_DATA_TEST_CASE(RunSmall, CLWinogradConvolutionLayerFastMathFixture16, framework::DatasetMode::PRECOMMIT,
combine(combine(combine(datasets::SmallWinogradConvolutionLayer3x3Dataset(),
@@ -856,7 +859,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture16, fr
framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
- validate(CLAccessor(_target), _reference, tolerance_convolution_layer_f16);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_convolution_layer_f16);
}
TEST_SUITE_END() // Conv3x3
@@ -878,7 +881,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture16, fr
framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
- validate(CLAccessor(_target), _reference, tolerance_convolution_layer_f16);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_convolution_layer_f16);
}
TEST_SUITE_END() // Conv3x1
@@ -900,7 +903,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture16, fr
framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
{
// Validate output
- validate(CLAccessor(_target), _reference, tolerance_convolution_layer_f16);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_convolution_layer_f16);
}
TEST_SUITE_END() // Conv1x3
@@ -924,7 +927,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture16, fr
{
// Validate output
- validate(CLAccessor(_target), _reference, tolerance_convolution_layer_f16);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_convolution_layer_f16);
}
TEST_SUITE_END() // Conv5x5
@@ -948,7 +951,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture16, fr
{
// Validate output
- validate(CLAccessor(_target), _reference, tolerance_convolution_layer_f16);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_convolution_layer_f16);
}
TEST_SUITE_END() // Conv5x1
@@ -972,12 +975,12 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture16, fr
{
// Validate output
- validate(CLAccessor(_target), _reference, tolerance_convolution_layer_f16);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_convolution_layer_f16);
}
TEST_SUITE_END() // Conv1x5
TEST_SUITE_END() // FP16
-#endif /*#ifdef WINOGRAD_F16_SUPPORT*/
+
TEST_SUITE_END() // ConvolutionLayer
TEST_SUITE_END() // Winograd
TEST_SUITE_END() // CL
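The FP16 validations above combine three thresholds. As understood here (a sketch of the framework's semantics, not its actual implementation): an element passes if it is within the absolute tolerance or within the relative tolerance, and the tensor passes if at most a tolerance_num fraction of its elements fail:

    #include <cmath>

    // Hypothetical per-element check mirroring validate(target, reference, rel, num, abs).
    bool element_ok(float target, float ref, float abs_tol, float rel_tol)
    {
        const float diff = std::abs(target - ref);
        return diff <= abs_tol || diff <= rel_tol * std::abs(ref);
    }
    // Tensor-level: pass if (failing elements / total elements) <= tolerance_num.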
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index d216d9db86..d8710ee91b 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -120,7 +120,7 @@ template <typename T>
using NEWinogradConvolutionLayerFixture = WinogradConvolutionLayerFastMathValidationFixture<Tensor, Accessor, NEWinogradConvolutionLayer, T>;
template <typename T>
-using NEWinogradConvolutionLayerNoBiasFixture = WinogradConvolutionLayerFastMathValidationFixture<Tensor, Accessor, NEWinogradConvolutionLayer, T, false>;
+using NEWinogradConvolutionLayerNoBiasFixture = WinogradConvolutionLayerFastMathValidationFixture<Tensor, Accessor, NEWinogradConvolutionLayer, T, T, false>;
TEST_SUITE(FP32)
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
index 15ce201222..9c9e634205 100644
--- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
@@ -39,6 +39,7 @@
#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/Utils.h"
#include "tests/validation/reference/Winograd.h"
+#include "utils/Utils.h"
#include <random>
@@ -156,7 +157,7 @@ protected:
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool use_bias = true>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename T1 = T, bool use_bias = true>
class WinogradConvolutionLayerFastMathValidationFixture : public framework::Fixture
{
public:
@@ -177,6 +178,11 @@ protected:
switch(tensor.data_type())
{
case DataType::F16:
+ {
+ arm_compute::utils::uniform_real_distribution_fp16 distribution((half)min, (half)max);
+ library->fill(tensor, distribution, i);
+ break;
+ }
case DataType::F32:
{
std::uniform_real_distribution<> distribution(min, max);
@@ -245,21 +251,25 @@ protected:
DataType data_type, ActivationLayerInfo act_info)
{
// Create reference
- SimpleTensor<T> src{ input_shape, data_type, 1 };
- SimpleTensor<T> weights{ weights_shape, data_type, 1 };
- SimpleTensor<T> bias{ bias_shape, data_type, 1 };
+ SimpleTensor<T> src_t{ input_shape, data_type, 1 };
+ SimpleTensor<T> weights_t{ weights_shape, data_type, 1 };
+ SimpleTensor<T> bias_t{ bias_shape, data_type, 1 };
// Fill reference
- fill(src, 0, -1.f, 1.f);
- fill(weights, 1, -1.f, 1.f);
+ fill(src_t, 0, -1.f, 1.f);
+ SimpleTensor<T1> src_t1(copy_tensor<T1, T>(src_t));
+
+ fill(weights_t, 1, -1.f, 1.f);
+ SimpleTensor<T1> weights_t1(copy_tensor<T1, T>(weights_t));
if(use_bias)
{
- fill(bias, 2, -1.f, 1.f);
+ fill(bias_t, 2, -1.f, 1.f);
}
else
{
- fill(bias, 2, 0.f, 0.f);
+ fill(bias_t, 2, 0.f, 0.f);
}
+ SimpleTensor<T1> bias_t1(copy_tensor<T1, T>(bias_t));
// Set output tile
Size2D output_tile(4U, 4U);
@@ -286,7 +296,7 @@ protected:
Size2D(weights_shape[0], weights_shape[1]),
Size2D(input_shape[0], input_shape[1]),
info,
- src.data_layout());
+ src_t1.data_layout());
// Compute tensor shapes for input, filter and output transforms
TensorShape input_transform_shape = compute_winograd_input_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
@@ -296,15 +306,16 @@ protected:
TensorShape output_transform_shape = compute_winograd_output_transform_shape(TensorInfo(batched_gemm_shape, 1, data_type), winograd_info);
// Dummy matrix C to perform matrix multiplication
- SimpleTensor<T> dummy_c{ batched_gemm_shape, data_type, 1 };
+ SimpleTensor<T1> dummy_c{ batched_gemm_shape, data_type, 1 };
// Compute Winograd-based convolution
- SimpleTensor<T> input_transform_out = reference::winograd_input_transform<T>(src, input_transform_shape, winograd_info);
- SimpleTensor<T> filter_transform_out = reference::winograd_filter_transform<T>(weights, filter_transform_shape, winograd_info);
- SimpleTensor<T> batched_gemm = reference::gemm<T>(input_transform_out, filter_transform_out, dummy_c, 1.0f, 0.0f);
- SimpleTensor<T> conv_out = reference::winograd_output_transform<T>(batched_gemm, bias, output_transform_shape, winograd_info);
+ SimpleTensor<T1> input_transform_out = reference::winograd_input_transform<T1>(src_t1, input_transform_shape, winograd_info);
- return (act_info.enabled()) ? reference::activation_layer<T>(conv_out, act_info) : conv_out;
+ SimpleTensor<T1> filter_transform_out = reference::winograd_filter_transform<T1>(weights_t1, filter_transform_shape, winograd_info);
+ SimpleTensor<T1> batched_gemm = reference::gemm<T1>(input_transform_out, filter_transform_out, dummy_c, 1.0f, 0.0f);
+ SimpleTensor<T1> conv_out = reference::winograd_output_transform<T1>(batched_gemm, bias_t1, output_transform_shape, winograd_info);
+ SimpleTensor<T> conv_out_t(copy_tensor<T, T1>(conv_out));
+ return (act_info.enabled()) ? reference::activation_layer<T>(conv_out_t, act_info) : conv_out_t;
}
TensorType _target{};
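Condensed, the fixture's reference path for T = half, T1 = float is promote, compute in FP32, demote, mirroring the _acc32 kernels. A toy function of the same shape (scale_in_fp32 is a hypothetical stand-in for the transform/GEMM chain):

    SimpleTensor<half> scale_in_fp32(const SimpleTensor<half> &src, float alpha)
    {
        SimpleTensor<float> tmp = copy_tensor<float, half>(src); // promote
        for(int n = 0; n < tmp.num_elements(); ++n)
        {
            tmp.data()[n] *= alpha; // all arithmetic in FP32
        }
        return copy_tensor<half, float>(tmp); // demote once, for the final comparison
    }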
diff --git a/utils/Utils.h b/utils/Utils.h
index 130e1f72fe..92ab1a30b9 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -181,6 +181,8 @@ inline std::string get_typestring(DataType data_type)
return endianness + "u" + support::cpp11::to_string(sizeof(uint64_t));
case DataType::S64:
return endianness + "i" + support::cpp11::to_string(sizeof(int64_t));
+ case DataType::F16:
+ return endianness + "f" + support::cpp11::to_string(sizeof(half));
case DataType::F32:
return endianness + "f" + support::cpp11::to_string(sizeof(float));
case DataType::F64:
@@ -275,6 +277,43 @@ inline void unmap(GCTensor &tensor)
}
#endif /* ARM_COMPUTE_GC */
+/** Specialized class to generate random non-zero FP16 values.
+ * std::uniform_real_distribution<half> can produce values that round to zero in FP16,
+ * causing mismatches between the ACL output and the reference implementation. This
+ * distribution therefore skips the interval (-0.3, 0.3).
+ */
+class uniform_real_distribution_fp16
+{
+ half min{ 0.0f }, max{ 0.0f };
+ std::uniform_real_distribution<float> neg{ min, -0.3f };
+ std::uniform_real_distribution<float> pos{ 0.3f, max };
+ std::uniform_int_distribution<uint8_t> sign_picker{ 0, 1 };
+
+public:
+ using result_type = half;
+ /** Constructor
+ *
+ * @param[in] a Minimum value of the distribution
+ * @param[in] b Maximum value of the distribution
+ */
+ explicit uniform_real_distribution_fp16(half a = half(0.0), half b = half(1.0))
+ : min(a), max(b)
+ {
+ }
+
+ /** () operator to generate the next value
+ *
+ * @param[in] gen A uniform random bit generator object
+ */
+ half operator()(std::mt19937 &gen)
+ {
+ if(sign_picker(gen))
+ {
+ return (half)neg(gen);
+ }
+ return (half)pos(gen);
+ }
+};
+
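A brief usage sketch for the distribution above (hypothetical caller, not part of the patch); draws land in [a, -0.3] or [0.3, b], so no generated value can round to zero in FP16:

    #include <random>

    arm_compute::half sample_nonzero_fp16(std::mt19937 &gen)
    {
        arm_compute::utils::uniform_real_distribution_fp16 dist(arm_compute::half(-1.0f), arm_compute::half(1.0f));
        return dist(gen);
    }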
/** Numpy data loader */
class NPYLoader
{
@@ -416,6 +455,7 @@ public:
case arm_compute::DataType::QASYMM8:
case arm_compute::DataType::S32:
case arm_compute::DataType::F32:
+ case arm_compute::DataType::F16:
{
// Read data
if(!are_layouts_different && !_fortran_order && tensor.info()->padding().empty())
@@ -699,6 +739,18 @@ void fill_random_tensor(T &tensor, float lower_bound, float upper_bound)
switch(tensor.info()->data_type())
{
+ case arm_compute::DataType::F16:
+ {
+ std::uniform_real_distribution<float> dist(lower_bound, upper_bound);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ *reinterpret_cast<half *>(it.ptr()) = (half)dist(gen);
+ },
+ it);
+
+ break;
+ }
case arm_compute::DataType::F32:
{
std::uniform_real_distribution<float> dist(lower_bound, upper_bound);