diff options
author | Freddie Liardet <frederick.liardet@arm.com> | 2021-06-10 16:45:58 +0100 |
---|---|---|
committer | frederick.liardet <frederick.liardet@arm.com> | 2021-06-22 12:39:24 +0000 |
commit | ef5aac6c1e119e8db16a33332b5551829f409786 (patch) | |
tree | 78b7d34c9da20e8c6e4393981ada8e1253f239b2 /src/core/CL/kernels/CLRemapKernel.cpp | |
parent | 8266ae50a7da14ce27592f89181287be81969fd0 (diff) | |
download | ComputeLibrary-ef5aac6c1e119e8db16a33332b5551829f409786.tar.gz |
Add FP16 support to CLRemap
Add FP16 support to CLRemap when data layout is NHWC.
Add relevant tests for FP16 and validation.
Update NERemap function level to be consistent with CLRemap.
Add deprecation notice for uint8_t-only function-level methods.
Resolves: COMPMID-4335
Signed-off-by: Freddie Liardet <frederick.liardet@arm.com>
Change-Id: If05f06801aef7a169b73ff1ebe760a42f11ca05c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5816
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLRemapKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLRemapKernel.cpp | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLRemapKernel.cpp b/src/core/CL/kernels/CLRemapKernel.cpp index 6edd744db7..7e3157c99d 100644 --- a/src/core/CL/kernels/CLRemapKernel.cpp +++ b/src/core/CL/kernels/CLRemapKernel.cpp @@ -32,8 +32,6 @@ #include "src/core/AccessWindowStatic.h" #include "src/core/helpers/WindowHelpers.h" -#include <algorithm> - namespace arm_compute { CLRemapKernel::CLRemapKernel() @@ -54,11 +52,18 @@ void CLRemapKernel::set_constant_border(unsigned int idx, const PixelValue &cons ICLKernel::add_argument<T>(idx, static_cast<T>(value)); } -Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, ITensorInfo *output, RemapInfo info) +Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, const ITensorInfo *output, RemapInfo info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, map_x, map_y, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8); + if(input->data_layout() == DataLayout::NCHW) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::F16); + } + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() != output->data_type(), "Input/output have different data types"); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(map_x, 1, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(map_y, 1, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.policy == InterpolationPolicy::AREA, "Area interpolation is not supported!"); @@ -68,7 +73,8 @@ Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_ void CLRemapKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *map_x, const ICLTensor *map_y, ICLTensor *output, RemapInfo info) 
{ - CLRemapKernel::validate(input->info(), map_x->info(), map_y->info(), output->info(), info); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, map_x, map_y, output); + ARM_COMPUTE_ERROR_THROW_ON(CLRemapKernel::validate(input->info(), map_x->info(), map_y->info(), output->info(), info)); _input = input; _output = output; @@ -118,7 +124,19 @@ void CLRemapKernel::configure(const CLCompileContext &compile_context, const ICL _kernel.setArg<cl_float>(idx++, input_height); if(is_nhwc && is_constant_border) { - set_constant_border<uint8_t>(idx, info.constant_border_value); + switch(input->info()->data_type()) + { + case DataType::U8: + set_constant_border<uint8_t>(idx, info.constant_border_value); + break; + case DataType::F16: + static_assert(sizeof(cl_half) == sizeof(half), "Half must be same size as cl_half"); + static_assert(sizeof(cl_half) == 2, "Half must be 16 bit"); + set_constant_border<half>(idx, info.constant_border_value); + break; + default: + ARM_COMPUTE_ERROR("Data Type not handled"); + } } } |