From e03342e3ba78ecf5b9128339dd47c30e00cb8565 Mon Sep 17 00:00:00 2001
From: Michalis Spyrou
Date: Mon, 15 Jan 2018 14:39:13 +0000
Subject: COMPMID-799 - Use new OpenCL 8-bit dot product instruction

Change-Id: I03d6c6db13bcb565f117725bdab2b68c89a49e21
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122185
Reviewed-by: Anthony Barbier
Tested-by: Jenkins
Reviewed-by: Gian Marco Iodice
---
 src/core/CL/CLHelpers.cpp          |  5 ++
 src/core/CL/CLKernelLibrary.cpp    |  5 ++
 src/core/CL/cl_kernels/gemmlowp.cl | 98 ++++++++++++++++++++++++++++++++++++++
 src/core/CL/cl_kernels/helpers.h   |  4 ++
 4 files changed, 112 insertions(+)

(limited to 'src')

diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index cda29d69d1..23c24c0337 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -129,6 +129,11 @@ bool fp16_supported(const cl::Device &device)
     return device_supports_extension(device, "cl_khr_fp16");
 }
 
+bool dot8_supported(const cl::Device &device)
+{
+    return device_supports_extension(device, "cl_arm_integer_dot_product_int8");
+}
+
 CLVersion get_cl_version(const cl::Device &device)
 {
     std::string version_str = device.getInfo<CL_DEVICE_VERSION>();
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index cdde7ef75a..b4531b841b 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -753,6 +753,11 @@ Kernel CLKernelLibrary::create_kernel(const std::string &kernel_name, const Stri
         concat_str += " -DARM_COMPUTE_OPENCL_FP16_ENABLED=1 ";
     }
 
+    if(dot8_supported(_device))
+    {
+        concat_str += " -DARM_COMPUTE_OPENCL_DOT8_ENABLED=1 ";
+    }
+
     if(get_cl_version(_device) == CLVersion::CL20)
     {
         concat_str += " -cl-std=CL2.0 ";
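
Note on the host-side changes above: dot8_supported() reports whether the device exposes the
cl_arm_integer_dot_product_int8 extension, and create_kernel() forwards that as the
ARM_COMPUTE_OPENCL_DOT8_ENABLED build define so the .cl kernels can select the fast path at
compile time. A rough, self-contained sketch of the same gating pattern (not ComputeLibrary
code; it assumes the standard OpenCL C++ bindings and uses hypothetical helper names):

    // Sketch only: query the device's extension string and derive the build option.
    #include <CL/cl2.hpp> // header name varies between OpenCL SDKs
    #include <string>

    bool supports_arm_dot8(const cl::Device &device)
    {
        const std::string ext = device.getInfo<CL_DEVICE_EXTENSIONS>();
        return ext.find("cl_arm_integer_dot_product_int8") != std::string::npos;
    }

    std::string dot8_build_options(const cl::Device &device)
    {
        return supports_arm_dot8(device) ? " -DARM_COMPUTE_OPENCL_DOT8_ENABLED=1 " : "";
    }

The library's own device_supports_extension() presumably performs an equivalent
extension-string lookup.
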
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 5e144d73af..da915778e7 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -190,6 +190,63 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
 #if MULT_INTERLEAVE4X4_HEIGHT == 1
     for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
     {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        uchar16 a0 = vload16(0, src_addr_a);
+        uchar4 b0 = vload4(0, src_addr_b);
+        uchar4 b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
+        uchar4 b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
+        uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
+
+        // Accumulate
+        c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        a0 = vload16(0, src_addr_a + 16);
+        b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
+        b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
+        b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
+        b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
+
+        // Accumulate
+        c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+        c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+        c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+        c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+        c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Load values from matrix A (interleaved) and matrix B (transposed)
         uchar16 a0 = vload16(0, src_addr_a);
         uchar4 b0 = vload4(0, src_addr_b);
@@ -375,6 +432,7 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
         c31 += (ushort)a0.sF * b0.s1;
         c32 += (ushort)a0.sF * b0.s2;
         c33 += (ushort)a0.sF * b0.s3;
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #endif // MULT_INTERLEAVE4X4_HEIGHT == 1
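
Note on the hunks above: on devices with the extension, each output element of the 4x4 block is
updated with a single arm_dot per group of four columns instead of four widening multiply-adds.
The cl_arm_integer_dot_product_int8 extension provides arm_dot on uchar4 operands returning the
32-bit unsigned sum of the four element-wise byte products, so the #if path computes the same
value as the existing #else path. A hypothetical scalar reference (not part of the patch) of one
accumulation step:

    // Reference-only helper showing what one arm_dot call is expected to compute.
    uint dot8_reference(uchar4 a, uchar4 b)
    {
        return (uint)a.s0 * b.s0 + (uint)a.s1 * b.s1 + (uint)a.s2 * b.s2 + (uint)a.s3 * b.s3;
    }

    // So c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
    // adds a0.s0*b0.s0 + a0.s4*b1.s0 + a0.s8*b2.s0 + a0.sC*b3.s0 to c00, i.e. one 4-element
    // slice of the row-by-column dot product for output (0,0).
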
@@ -666,6 +724,13 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
         uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
 
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc00 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0);
+            acc01 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0);
+            acc02 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0);
+            acc03 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
@@ -691,9 +756,17 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
             acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc10 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1);
+            acc11 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1);
+            acc12 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1);
+            acc13 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
@@ -719,10 +792,18 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
             acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc20 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2);
+            acc21 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2);
+            acc22 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2);
+            acc23 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
@@ -748,10 +829,18 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
             acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc30 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3);
+            acc31 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3);
+            acc32 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3);
+            acc33 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
@@ -777,10 +866,18 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
             acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
         {
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+            // Accumulate
+            acc40 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4);
+            acc41 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4);
+            acc42 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4);
+            acc43 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
             // Accumulate
             ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
             ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
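
Note on the fallback (#else) paths in this kernel: the partial products are computed in ushort and
only widened to uint when summed into the accumulators. A single product of two 8-bit values cannot
overflow 16 bits, but the sum of four such products can:

    255 * 255     = 65025  <= 65535  (fits in ushort)
    4 * 255 * 255 = 260100 >  65535  (so the sums are accumulated in uint)

The arm_dot path performs the same widening internally and yields a 32-bit result directly.
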
@@ -806,6 +903,7 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
             acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
             acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
             acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
     }
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index 615c5188a1..f51eccb6d4 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -28,6 +28,10 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
 
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
 #if defined(ARM_COMPUTE_DEBUG_ENABLED)
 #if defined(cl_arm_printf)
 #pragma OPENCL EXTENSION cl_arm_printf : enable
-- 
cgit v1.2.1
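
For reference, a minimal, self-contained sketch (not part of the patch; kernel name and signature are
illustrative) of how the pieces interact: the host adds -DARM_COMPUTE_OPENCL_DOT8_ENABLED=1 only when
the device reports the extension, helpers.h enables the extension with a pragma, and kernel code
selects between arm_dot and a portable fallback at compile time:

    #if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
    #pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
    #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)

    __kernel void dot8_example(__global const uchar4 *a, __global const uchar4 *b, __global uint *out)
    {
        const uint gid = get_global_id(0);
    #if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
        // Single 8-bit dot product instruction on devices exposing the extension.
        out[gid] = arm_dot(a[gid], b[gid]);
    #else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
        // Portable fallback: widen each byte product and accumulate manually.
        const uchar4 va = a[gid];
        const uchar4 vb = b[gid];
        out[gid] = (uint)va.s0 * vb.s0 + (uint)va.s1 * vb.s1 + (uint)va.s2 * vb.s2 + (uint)va.s3 * vb.s3;
    #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
    }

Building this kernel with the define present (for example by passing
"-DARM_COMPUTE_OPENCL_DOT8_ENABLED=1" to clBuildProgram) exercises the arm_dot path; without it, the
fallback compiles on any OpenCL device.
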