aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/Scale.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/Scale.cpp')
-rw-r--r--tests/validation/reference/Scale.cpp47
1 files changed, 29 insertions, 18 deletions
diff --git a/tests/validation/reference/Scale.cpp b/tests/validation/reference/Scale.cpp
index 84f4fb83c1..63a2853c66 100644
--- a/tests/validation/reference/Scale.cpp
+++ b/tests/validation/reference/Scale.cpp
@@ -71,28 +71,25 @@ SimpleTensor<T> scale_core(const SimpleTensor<T> &in, float scale_x, float scale
float x_src = 0;
float y_src = 0;
- switch(sampling_policy)
- {
- case SamplingPolicy::TOP_LEFT:
- x_src = idx * wr;
- y_src = idy * hr;
- break;
- case SamplingPolicy::CENTER:
- x_src = (idx + 0.5f) * wr - 0.5f;
- y_src = (idy + 0.5f) * hr - 0.5f;
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported sampling policy.");
- break;
- }
-
switch(policy)
{
case InterpolationPolicy::NEAREST_NEIGHBOR:
{
- //Calculate the source coords without -0.5f is equivalent to round the x_scr/y_src coords
- x_src = (idx + 0.5f) * wr;
- y_src = (idy + 0.5f) * hr;
+ switch(sampling_policy)
+ {
+ case SamplingPolicy::TOP_LEFT:
+ x_src = std::floor(idx * wr);
+ y_src = std::floor(idy * hr);
+ break;
+ case SamplingPolicy::CENTER:
+ //Calculate the source coords without -0.5f is equivalent to round the x_scr/y_src coords
+ x_src = (idx + 0.5f) * wr;
+ y_src = (idy + 0.5f) * hr;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported sampling policy.");
+ }
+
id.set(0, x_src);
id.set(1, y_src);
@@ -105,6 +102,20 @@ SimpleTensor<T> scale_core(const SimpleTensor<T> &in, float scale_x, float scale
}
case InterpolationPolicy::BILINEAR:
{
+ switch(sampling_policy)
+ {
+ case SamplingPolicy::TOP_LEFT:
+ x_src = idx * wr;
+ y_src = idy * hr;
+ break;
+ case SamplingPolicy::CENTER:
+ x_src = (idx + 0.5f) * wr - 0.5f;
+ y_src = (idy + 0.5f) * hr - 0.5f;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported sampling policy.");
+ }
+
id.set(0, std::floor(x_src));
id.set(1, std::floor(y_src));
if(is_valid_pixel_index(x_src, y_src, width, height, border_size))