aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/CL/CLHelpers.h7
-rw-r--r--src/core/CL/CLHelpers.cpp26
-rw-r--r--src/core/CL/CLKernelLibrary.cpp5
-rw-r--r--src/core/CL/cl_kernels/gemm.cl4
-rw-r--r--src/core/CL/cl_kernels/helpers.h3
5 files changed, 30 insertions, 15 deletions
diff --git a/arm_compute/core/CL/CLHelpers.h b/arm_compute/core/CL/CLHelpers.h
index 1a4476e304..b93bae8d82 100644
--- a/arm_compute/core/CL/CLHelpers.h
+++ b/arm_compute/core/CL/CLHelpers.h
@@ -126,6 +126,13 @@ GPUTarget get_arch_from_target(GPUTarget target);
* @return the highest OpenCL version supported
*/
CLVersion get_cl_version(const cl::Device &device);
+/** Helper function to check whether the cl_khr_fp16 extension is supported
+ *
+ * @param[in] device A CL device
+ *
+ * @return True if the extension is supported
+ */
+bool fp16_support(const cl::Device &device);
/** Helper function to check whether the arm_non_uniform_work_group_size extension is supported
*
* @param[in] device A CL device
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index 09ec329e4c..901ac3f39a 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -58,6 +58,13 @@ arm_compute::GPUTarget get_midgard_target(const std::string &version)
return arm_compute::GPUTarget::MIDGARD;
}
}
+
+bool extension_support(const cl::Device &device, const char *extension_name)
+{
+ std::string extensions = device.getInfo<CL_DEVICE_EXTENSIONS>();
+ auto pos = extensions.find(extension_name);
+ return (pos != std::string::npos);
+}
} // namespace
namespace arm_compute
@@ -206,21 +213,12 @@ GPUTarget get_arch_from_target(GPUTarget target)
bool non_uniform_workgroup_support(const cl::Device &device)
{
- std::vector<char> extension;
- size_t extension_size = 0;
- cl_int err = clGetDeviceInfo(device.get(), CL_DEVICE_EXTENSIONS, 0, nullptr, &extension_size);
- ARM_COMPUTE_ERROR_ON_MSG((err != 0) || (extension_size == 0), "clGetDeviceInfo failed to return valid information");
- ARM_COMPUTE_UNUSED(err);
- // Resize vector
- extension.resize(extension_size);
- // Query extension
- err = clGetDeviceInfo(device.get(), CL_DEVICE_EXTENSIONS, extension_size, extension.data(), nullptr);
- ARM_COMPUTE_ERROR_ON_MSG(err != 0, "clGetDeviceInfo failed to return valid information");
- ARM_COMPUTE_UNUSED(err);
+ return extension_support(device, "cl_arm_non_uniform_work_group_size");
+}
- std::string extension_str(extension.begin(), extension.end());
- auto pos = extension_str.find("cl_arm_non_uniform_work_group_size");
- return (pos != std::string::npos);
+bool fp16_support(const cl::Device &device)
+{
+ return extension_support(device, "cl_khr_fp16");
}
CLVersion get_cl_version(const cl::Device &device)
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 62ef2593e7..9e2b5bd600 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -596,6 +596,11 @@ Kernel CLKernelLibrary::create_kernel(const std::string &kernel_name, const Stri
std::string concat_str;
+ if(fp16_support(_device))
+ {
+ concat_str += " -DARM_COMPUTE_OPENCL_FP16_ENABLED=1 ";
+ }
+
if(non_uniform_workgroup_support(_device))
{
concat_str += " -cl-arm-non-uniform-work-group-size ";
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 7f2a08bc2c..d08e821431 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -704,6 +704,7 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(offset(&dst, 0, 3)));
}
+#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
* Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
*
@@ -802,6 +803,7 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
vstore8(c20, 0, (__global half *)(offset(&dst, 0, 2)));
vstore8(c30, 0, (__global half *)(offset(&dst, 0, 3)));
}
+#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
#ifdef FIXED_POINT_POSITION
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
@@ -1652,4 +1654,4 @@ __kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
-#endif /* WIDTH_VECTOR_A */ \ No newline at end of file
+#endif /* WIDTH_VECTOR_A */
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index 4421e74816..330d67daa5 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -24,7 +24,10 @@
#ifndef ARM_COMPUTE_HELPER_H
#define ARM_COMPUTE_HELPER_H
+#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
+
#if defined(ARM_COMPUTE_DEBUG_ENABLED)
#pragma OPENCL EXTENSION cl_arm_printf : enable
#endif // defined(ARM_COMPUTE_DEBUG_ENABLED)