diff options
Diffstat (limited to 'src/core/CL/kernels')
-rw-r--r-- | src/core/CL/kernels/CLRemapKernel.cpp | 32 | ||||
-rw-r--r-- | src/core/CL/kernels/CLRemapKernel.h | 18 |
2 files changed, 31 insertions, 19 deletions
diff --git a/src/core/CL/kernels/CLRemapKernel.cpp b/src/core/CL/kernels/CLRemapKernel.cpp index 6edd744db7..7e3157c99d 100644 --- a/src/core/CL/kernels/CLRemapKernel.cpp +++ b/src/core/CL/kernels/CLRemapKernel.cpp @@ -32,8 +32,6 @@ #include "src/core/AccessWindowStatic.h" #include "src/core/helpers/WindowHelpers.h" -#include <algorithm> - namespace arm_compute { CLRemapKernel::CLRemapKernel() @@ -54,11 +52,18 @@ void CLRemapKernel::set_constant_border(unsigned int idx, const PixelValue &cons ICLKernel::add_argument<T>(idx, static_cast<T>(value)); } -Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, ITensorInfo *output, RemapInfo info) +Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, const ITensorInfo *output, RemapInfo info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, map_x, map_y, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8); + if(input->data_layout() == DataLayout::NCHW) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::F16); + } + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() != output->data_type(), "Input/output have different data types"); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(map_x, 1, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(map_y, 1, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.policy == InterpolationPolicy::AREA, "Area interpolation is not supported!"); @@ -68,7 +73,8 @@ Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_ void CLRemapKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *map_x, const ICLTensor *map_y, ICLTensor *output, RemapInfo info) { - CLRemapKernel::validate(input->info(), map_x->info(), map_y->info(), output->info(), info); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, map_x, map_y, output); + ARM_COMPUTE_ERROR_THROW_ON(CLRemapKernel::validate(input->info(), map_x->info(), map_y->info(), output->info(), info)); _input = input; _output = output; @@ -118,7 +124,19 @@ void CLRemapKernel::configure(const CLCompileContext &compile_context, const ICL _kernel.setArg<cl_float>(idx++, input_height); if(is_nhwc && is_constant_border) { - set_constant_border<uint8_t>(idx, info.constant_border_value); + switch(input->info()->data_type()) + { + case DataType::U8: + set_constant_border<uint8_t>(idx, info.constant_border_value); + break; + case DataType::F16: + static_assert(sizeof(cl_half) == sizeof(half), "Half must be same size as cl_half"); + static_assert(sizeof(cl_half) == 2, "Half must be 16 bit"); + set_constant_border<half>(idx, info.constant_border_value); + break; + default: + ARM_COMPUTE_ERROR("Data Type not handled"); + } } } diff --git a/src/core/CL/kernels/CLRemapKernel.h b/src/core/CL/kernels/CLRemapKernel.h index 1e3a4ad13f..93b0b4e660 100644 --- a/src/core/CL/kernels/CLRemapKernel.h +++ b/src/core/CL/kernels/CLRemapKernel.h @@ -49,28 +49,22 @@ public: /** Initialize the kernel's input, output and border mode. * * @param[in] compile_context The compile context to be used. - * @param[in] input Source tensor. Data types supported: U8. + * @param[in] input Source tensor. Data types supported: U8 (or F16 when layout is NHWC). * @param[in] map_x Map for X coordinates. Data types supported: F32. * @param[in] map_y Map for Y coordinates. Data types supported: F32. - * @param[out] output Destination tensor. Data types supported: U8. All but the lowest two dimensions must be the same size as in the input tensor, i.e. remapping is only performed within the XY-plane. + * @param[out] output Destination tensor. Data types supported: Same as @p input. All but the lowest two dimensions must be the same size as in the input tensor, i.e. remapping is only performed within the XY-plane. * @param[in] info RemapInfo struct: * - policy Interpolation policy to use. Only NEAREST and BILINEAR are supported. * - border_mode Border mode to use on the input tensor. Only CONSTANT and UNDEFINED are supported. * - constant_border_value Constant value to use for borders if border_mode is set to CONSTANT. */ void configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *map_x, const ICLTensor *map_y, ICLTensor *output, RemapInfo info); - /** Validate the kernel's input, output and border mode. + /** Checks if the kernel's input, output and border mode will lead to a valid configuration of @ref CLRemapKernel + * + * Similar to @ref CLRemapKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *map_x, const ICLTensor *map_y, ICLTensor *output, RemapInfo info) * - * @param[in] input Source tensor. Data types supported: U8. - * @param[in] map_x Map for X coordinates. Data types supported: F32. - * @param[in] map_y Map for Y coordinates. Data types supported: F32. - * @param[out] output Destination tensor. Data types supported: U8. All but the lowest two dimensions must be the same size as in the input tensor, i.e. remapping is only performed within the XY-plane. - * @param[in] info RemapInfo struct: - * - policy Interpolation policy to use. Only NEAREST and BILINEAR are supported. - * - border_mode Border mode to use on the input tensor. Only CONSTANT and UNDEFINED are supported. - * - constant_border_value Constant value to use for borders if border_mode is set to CONSTANT. */ - static Status validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, ITensorInfo *output, RemapInfo info); + static Status validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, const ITensorInfo *output, RemapInfo info); /** Function to set the constant value on fill border kernel depending on type. * * @param[in] idx Index of the kernel argument to set. |