diff options
Diffstat (limited to 'src/core/CL/CLUtils.h')
-rw-r--r-- | src/core/CL/CLUtils.h | 47 |
1 files changed, 41 insertions, 6 deletions
diff --git a/src/core/CL/CLUtils.h b/src/core/CL/CLUtils.h index b65d547756..f9dcfeac3a 100644 --- a/src/core/CL/CLUtils.h +++ b/src/core/CL/CLUtils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2020-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,14 +22,36 @@ * SOFTWARE. */ -#ifndef ARM_COMPUTE_CL_CLUTILS_H -#define ARM_COMPUTE_CL_CLUTILS_H +#ifndef ACL_SRC_CORE_CL_CLUTILS_H +#define ACL_SRC_CORE_CL_CLUTILS_H #include "arm_compute/core/CL/OpenCL.h" +#include <map> + namespace arm_compute { class TensorShape; +class CLBuildOptions; +class ITensorInfo; +class ICLTensor; +enum class DataType; + +/** OpenCL Image2D types */ +enum class CLImage2DType +{ + ReadOnly, + WriteOnly +}; + +/** Create a cl::Image2D object from a tensor + * + * @param[in] tensor Tensor from which to construct Image 2D object + * @param[in] image_type Image 2D type (@ref CLImage2DType) + * + * @return cl::Image2D object + */ +cl::Image2D create_image2d_from_tensor(const ICLTensor *tensor, CLImage2DType image_type); /** Create a cl::Image2D object from an OpenCL buffer * @@ -46,11 +68,24 @@ class TensorShape; * @param[in] shape2d 2D tensor shape * @param[in] data_type DataType to use. Only supported: F32,F16 * @param[in] image_row_pitch Image row pitch (a.k.a. stride Y) to be used in the image2d object + * @param[in] image_type Image 2D type (@ref CLImage2DType) * * @return cl::Image2D object */ -cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, const cl::Buffer &buffer, const TensorShape &shape2d, DataType data_type, size_t image_row_pitch); +cl::Image2D create_image2d_from_buffer(const cl::Context &ctx, + const cl::Buffer &buffer, + const TensorShape &shape2d, + DataType data_type, + size_t image_row_pitch, + CLImage2DType image_type); + +/** Check for CL error code and throw exception accordingly. + * + * @param[in] function_name The name of the CL function being called. + * @param[in] error_code The error returned by the CL function. + */ +void handle_cl_error(const std::string &function_name, cl_int error_code); -} // arm_compute +} // namespace arm_compute -#endif /* ARM_COMPUTE_CL_CLUTILS_H */ +#endif // ACL_SRC_CORE_CL_CLUTILS_H |