aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLBoundingBoxTransformKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLBoundingBoxTransformKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLBoundingBoxTransformKernel.cpp12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/core/CL/kernels/CLBoundingBoxTransformKernel.cpp b/src/core/CL/kernels/CLBoundingBoxTransformKernel.cpp
index 09f3d33f5b..bff28e3ed9 100644
--- a/src/core/CL/kernels/CLBoundingBoxTransformKernel.cpp
+++ b/src/core/CL/kernels/CLBoundingBoxTransformKernel.cpp
@@ -39,7 +39,7 @@ namespace arm_compute
{
namespace
{
-Status validate_arguments(const ITensorInfo *boxes, const ITensorInfo *pred_boxes, const ITensorInfo *deltas)
+Status validate_arguments(const ITensorInfo *boxes, const ITensorInfo *pred_boxes, const ITensorInfo *deltas, const BoundingBoxTransformInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(boxes, pred_boxes, deltas);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(boxes, DataType::F32, DataType::F16);
@@ -56,6 +56,7 @@ Status validate_arguments(const ITensorInfo *boxes, const ITensorInfo *pred_boxe
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(pred_boxes, deltas);
ARM_COMPUTE_RETURN_ERROR_ON(pred_boxes->num_dimensions() > 2);
}
+ ARM_COMPUTE_RETURN_ERROR_ON(info.scale() <= 0);
return Status{};
}
} // namespace
@@ -70,6 +71,8 @@ void CLBoundingBoxTransformKernel::configure(const ICLTensor *boxes, ICLTensor *
ARM_COMPUTE_ERROR_ON_NULLPTR(boxes, pred_boxes, deltas);
auto_init_if_empty(*pred_boxes->info(), *deltas->info());
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(boxes->info(), pred_boxes->info(), deltas->info(), info));
+
// Set instance variables
_boxes = boxes;
_pred_boxes = pred_boxes;
@@ -90,7 +93,9 @@ void CLBoundingBoxTransformKernel::configure(const ICLTensor *boxes, ICLTensor *
build_opts.add_option("-DIMG_WIDTH=" + support::cpp11::to_string(img_w));
build_opts.add_option("-DIMG_HEIGHT=" + support::cpp11::to_string(img_h));
build_opts.add_option("-DBOX_FIELDS=" + support::cpp11::to_string(4));
- build_opts.add_option_if(info.apply_scale(), "-DSCALE=" + float_to_string_with_full_precision(info.scale()));
+ build_opts.add_option("-DSCALE_BEFORE=" + float_to_string_with_full_precision(info.scale()));
+ build_opts.add_option_if(info.apply_scale(), "-DSCALE_AFTER=" + float_to_string_with_full_precision(info.scale()));
+ build_opts.add_option_if(info.correct_transform_coords(), "-DOFFSET=1");
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("bounding_box_transform", build_opts.options()));
@@ -103,8 +108,7 @@ void CLBoundingBoxTransformKernel::configure(const ICLTensor *boxes, ICLTensor *
Status CLBoundingBoxTransformKernel::validate(const ITensorInfo *boxes, const ITensorInfo *pred_boxes, const ITensorInfo *deltas, const BoundingBoxTransformInfo &info)
{
- ARM_COMPUTE_UNUSED(info);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(boxes, pred_boxes, deltas));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(boxes, pred_boxes, deltas, info));
return Status{};
}