aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLRemapKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLRemapKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLRemapKernel.cpp32
1 files changed, 25 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLRemapKernel.cpp b/src/core/CL/kernels/CLRemapKernel.cpp
index 6edd744db7..7e3157c99d 100644
--- a/src/core/CL/kernels/CLRemapKernel.cpp
+++ b/src/core/CL/kernels/CLRemapKernel.cpp
@@ -32,8 +32,6 @@
#include "src/core/AccessWindowStatic.h"
#include "src/core/helpers/WindowHelpers.h"
-#include <algorithm>
-
namespace arm_compute
{
CLRemapKernel::CLRemapKernel()
@@ -54,11 +52,18 @@ void CLRemapKernel::set_constant_border(unsigned int idx, const PixelValue &cons
ICLKernel::add_argument<T>(idx, static_cast<T>(value));
}
-Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, ITensorInfo *output, RemapInfo info)
+Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_x, const ITensorInfo *map_y, const ITensorInfo *output, RemapInfo info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, map_x, map_y, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
+ if(input->data_layout() == DataLayout::NCHW)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::F16);
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() != output->data_type(), "Input/output have different data types");
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(map_x, 1, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(map_y, 1, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.policy == InterpolationPolicy::AREA, "Area interpolation is not supported!");
@@ -68,7 +73,8 @@ Status CLRemapKernel::validate(const ITensorInfo *input, const ITensorInfo *map_
void CLRemapKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *map_x, const ICLTensor *map_y, ICLTensor *output, RemapInfo info)
{
- CLRemapKernel::validate(input->info(), map_x->info(), map_y->info(), output->info(), info);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, map_x, map_y, output);
+ ARM_COMPUTE_ERROR_THROW_ON(CLRemapKernel::validate(input->info(), map_x->info(), map_y->info(), output->info(), info));
_input = input;
_output = output;
@@ -118,7 +124,19 @@ void CLRemapKernel::configure(const CLCompileContext &compile_context, const ICL
_kernel.setArg<cl_float>(idx++, input_height);
if(is_nhwc && is_constant_border)
{
- set_constant_border<uint8_t>(idx, info.constant_border_value);
+ switch(input->info()->data_type())
+ {
+ case DataType::U8:
+ set_constant_border<uint8_t>(idx, info.constant_border_value);
+ break;
+ case DataType::F16:
+ static_assert(sizeof(cl_half) == sizeof(half), "Half must be same size as cl_half");
+ static_assert(sizeof(cl_half) == 2, "Half must be 16 bit");
+ set_constant_border<half>(idx, info.constant_border_value);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data Type not handled");
+ }
}
}