aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/core/CL/CLUtils.cpp18
-rw-r--r--src/core/CL/CLUtils.h4
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp2
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp2
4 files changed, 20 insertions, 6 deletions
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 5d0cdf7f46..67af240044 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -26,12 +26,26 @@
#include "src/core/CL/CLUtils.h"
-cl::Image2D arm_compute::create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, cl_channel_type data_type, size_t image_row_pitch)
+cl::Image2D arm_compute::create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, DataType data_type, size_t image_row_pitch)
{
+ cl_channel_type cl_data_type;
+
+ switch(data_type)
+ {
+ case DataType::F32:
+ cl_data_type = CL_FLOAT;
+ break;
+ case DataType::F16:
+ cl_data_type = CL_HALF_FLOAT;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not support with OpenCL image2d");
+ }
+
cl_mem cl_image;
cl_int err = CL_SUCCESS;
- const cl_image_format format = { CL_RGBA, data_type };
+ const cl_image_format format = { CL_RGBA, cl_data_type };
cl_image_desc desc;
memset(&desc, 0, sizeof(desc));
diff --git a/src/core/CL/CLUtils.h b/src/core/CL/CLUtils.h
index 8f1c58bcba..b65d547756 100644
--- a/src/core/CL/CLUtils.h
+++ b/src/core/CL/CLUtils.h
@@ -44,12 +44,12 @@ class TensorShape;
* @param[in] ctx cl::Context object
* @param[in] buffer cl::Buffer object from which the OpenCL image2d object is created
* @param[in] shape2d 2D tensor shape
- * @param[in] data_type cl_channel_type to use. Only supported CL_FLOAT
+ * @param[in] data_type DataType to use. Only supported: F32,F16
* @param[in] image_row_pitch Image row pitch (a.k.a. stride Y) to be used in the image2d object
*
* @return cl::Image2D object
*/
-cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, cl_channel_type data_type, size_t image_row_pitch);
+cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, DataType data_type, size_t image_row_pitch);
} // arm_compute
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 8f20de1ea1..b0f0e8a81f 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -376,7 +376,7 @@ void CLGEMMMatrixMultiplyReshapedKernel::run(const Window &window, cl::CommandQu
const TensorShape shape2d(_input1->info()->dimension(0) / 4, _input1->info()->dimension(1) * _input1->info()->dimension(2));
const size_t image_row_pitch = _input1->info()->strides_in_bytes()[1];
- input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, CL_FLOAT, image_row_pitch);
+ input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, _input1->info()->data_type(), image_row_pitch);
}
do
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index cf77c70bfa..0ae30ed30e 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -378,7 +378,7 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::run(const Window &window, cl::Co
const TensorShape shape2d(_input1->info()->dimension(0) / 4, _input1->info()->dimension(1) * _input1->info()->dimension(2));
const size_t image_row_pitch = _input1->info()->strides_in_bytes()[1];
- input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, CL_FLOAT, image_row_pitch);
+ input1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), _input1->cl_buffer(), shape2d, _input1->info()->data_type(), image_row_pitch);
}
do