aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/scale.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/scale.cl')
-rw-r--r--src/core/CL/cl_kernels/scale.cl22
1 files changed, 19 insertions, 3 deletions
diff --git a/src/core/CL/cl_kernels/scale.cl b/src/core/CL/cl_kernels/scale.cl
index 5ac6443c98..499f9ea53f 100644
--- a/src/core/CL/cl_kernels/scale.cl
+++ b/src/core/CL/cl_kernels/scale.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2018 ARM Limited.
+ * Copyright (c) 2016-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,10 +33,19 @@
*/
inline const float8 transform_nearest(const float2 coord, const float2 scale)
{
+#ifdef SAMPLING_POLICY_TOP_LEFT
+ const float4 in_x_coords = (float4)(coord.s0, 1 + coord.s0, 2 + coord.s0, 3 + coord.s0);
+ const float4 new_x = in_x_coords * (float4)(scale.s0);
+ const float4 new_y = (float4)(coord.s1 * scale.s1);
+ return (float8)(new_x.s0, new_y.s0, new_x.s1, new_y.s1, new_x.s2, new_y.s2, new_x.s3, new_y.s3);
+#elif SAMPLING_POLICY_CENTER
const float4 in_x_coords = (float4)(coord.s0, 1 + coord.s0, 2 + coord.s0, 3 + coord.s0);
const float4 new_x = (in_x_coords + ((float4)(0.5f))) * (float4)(scale.s0);
const float4 new_y = (float4)((coord.s1 + 0.5f) * scale.s1);
return (float8)(new_x.s0, new_y.s0, new_x.s1, new_y.s1, new_x.s2, new_y.s2, new_x.s3, new_y.s3);
+#else /* SAMPLING_POLICY */
+#error("Unsupported sampling policy");
+#endif /* SAMPLING_POLICY */
}
/** Transforms four 2D coordinates. This is used to map the output coordinates to the input coordinates.
@@ -172,8 +181,15 @@ __kernel void scale_nearest_neighbour_nhwc(
Tensor4D in = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(in, 0);
Tensor4D out = CONVERT_TO_TENSOR4D_STRUCT(out, DEPTH_OUT);
- const float new_x = (get_global_id(1) + 0.5f) * scale_x;
- const float new_y = ((get_global_id(2) % DEPTH_OUT) + 0.5f) * scale_y;
+#ifdef SAMPLING_POLICY_TOP_LEFT
+ const float new_x = get_global_id(1) * scale_x;
+ const float new_y = (get_global_id(2) % DEPTH_OUT) * scale_y;
+#elif SAMPLING_POLICY_CENTER
+ const float new_x = (get_global_id(1) + 0.5f) * scale_x;
+ const float new_y = ((get_global_id(2) % DEPTH_OUT) + 0.5f) * scale_y;
+#else /* SAMPLING_POLICY */
+#error("Unsupported sampling policy");
+#endif /* SAMPLING_POLICY */
const float clamped_x = clamp(new_x, 0.0f, input_width - 1);
const float clamped_y = clamp(new_y, 0.0f, input_height - 1);