diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2020-09-15 14:17:41 +0100 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2020-09-18 09:53:42 +0000 |
commit | 6f9313477f6a158210479996523c210452d4f07a (patch) | |
tree | a64f8b97f3e95b6d084955cf675fa5e6d19205a7 /src | |
parent | 82c1a1fc63d6a49c0b4be39529412c7f7bc8ea64 (diff) | |
download | ComputeLibrary-6f9313477f6a158210479996523c210452d4f07a.tar.gz |
COMPMID-3671: Extend cl image support to f16 in CLGEMMMatrixMultiplyReshapedKernel
Resolves: COMPMID-3671, COMPMID-3672
- Extend cl image support to f16 in CLGEMMMatrixMultiplyReshapedKernel
- Extend cl image support to f16 in CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
- Change the interface of create_image2d_from_buffer
- Extend test
Change-Id: I27363be71fa515fbf71aa4be5ed0d6c730f38f34
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3992
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/core/CL/CLUtils.cpp | 18 | ||||
-rw-r--r-- | src/core/CL/CLUtils.h | 4 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp | 2 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 2 |
4 files changed, 20 insertions, 6 deletions
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 5d0cdf7f46..67af240044 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -26,12 +26,26 @@
 
 #include "src/core/CL/CLUtils.h"
 
-cl::Image2D arm_compute::create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, cl_channel_type data_type, size_t image_row_pitch)
+cl::Image2D arm_compute::create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, DataType data_type, size_t image_row_pitch)
 {
+    cl_channel_type cl_data_type;
+
+    switch(data_type)
+    {
+        case DataType::F32:
+            cl_data_type = CL_FLOAT;
+            break;
+        case DataType::F16:
+            cl_data_type = CL_HALF_FLOAT;
+            break;
+        default:
+            ARM_COMPUTE_ERROR("Data type not support with OpenCL image2d");
+    }
+
     cl_mem cl_image;
     cl_int err = CL_SUCCESS;
 
-    const cl_image_format format = { CL_RGBA, data_type };
+    const cl_image_format format = { CL_RGBA, cl_data_type };
 
     cl_image_desc desc;
     memset(&desc, 0, sizeof(desc));
diff --git a/src/core/CL/CLUtils.h b/src/core/CL/CLUtils.h
index 8f1c58bcba..b65d547756 100644
--- a/src/core/CL/CLUtils.h
+++ b/src/core/CL/CLUtils.h
@@ -44,12 +44,12 @@ class TensorShape;
  * @param[in] ctx cl::Context object
  * @param[in] buffer cl::Buffer object from which the OpenCL image2d object is created
  * @param[in] shape2d 2D tensor shape
- * @param[in] data_type cl_channel_type to use. Only supported CL_FLOAT
+ * @param[in] data_type DataType to use. Only supported: F32,F16
  * @param[in] image_row_pitch Image row pitch (a.k.a. stride Y) to be used in the image2d object
  *
  * @return cl::Image2D object
  */
-cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, cl_channel_type data_type, size_t image_row_pitch);
+cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, DataType data_type, size_t image_row_pitch);
 
 } // arm_compute
 
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 8f20de1ea1..b0f0e8a81f 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -376,7 +376,7 @@ void CLGEMMMatrixMultiplyReshapedKernel::run(const Window &window, cl::CommandQu
         const TensorShape shape2d(_input1->info()->dimension(0) / 4, _input1->info()->dimension(1) * _input1->info()->dimension(2));
         const size_t image_row_pitch = _input1->info()->strides_in_bytes()[1];
 
-        input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, CL_FLOAT, image_row_pitch);
+        input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, _input1->info()->data_type(), image_row_pitch);
     }
 
     do
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index cf77c70bfa..0ae30ed30e 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -378,7 +378,7 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::run(const Window &window, cl::Co
         const TensorShape shape2d(_input1->info()->dimension(0) / 4, _input1->info()->dimension(1) * _input1->info()->dimension(2));
         const size_t image_row_pitch = _input1->info()->strides_in_bytes()[1];
 
-        input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, CL_FLOAT, image_row_pitch);
+        input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, _input1->info()->data_type(), image_row_pitch);
     }
 
     do