aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CPP/Scale.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CPP/Scale.cpp')
-rw-r--r--tests/validation/CPP/Scale.cpp32
1 files changed, 25 insertions, 7 deletions
diff --git a/tests/validation/CPP/Scale.cpp b/tests/validation/CPP/Scale.cpp
index c368fa277a..727325f675 100644
--- a/tests/validation/CPP/Scale.cpp
+++ b/tests/validation/CPP/Scale.cpp
@@ -26,6 +26,7 @@
#include "Scale.h"
#include "Utils.h"
+#include "support/ToolchainSupport.h"
namespace arm_compute
{
@@ -36,7 +37,8 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, bool ceil_policy_scale)
+SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, T constant_border_value,
+ SamplingPolicy sampling_policy, bool ceil_policy_scale)
{
// Add 1 if ceil_policy_scale is true
const size_t round_value = ceil_policy_scale ? 1U : 0U;
@@ -66,8 +68,23 @@ SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, I
Coordinates id = index2coord(out.shape(), element_idx);
int idx = id.x();
int idy = id.y();
- float x_src = (idx + 0.5f) * wr - 0.5f;
- float y_src = (idy + 0.5f) * hr - 0.5f;
+ float x_src = 0;
+ float y_src = 0;
+
+ switch(sampling_policy)
+ {
+ case SamplingPolicy::TOP_LEFT:
+ x_src = idx * wr;
+ y_src = idy * hr;
+ break;
+ case SamplingPolicy::CENTER:
+ x_src = (idx + 0.5f) * wr - 0.5f;
+ y_src = (idy + 0.5f) * hr - 0.5f;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported sampling policy.");
+ break;
+ }
switch(policy)
{
@@ -152,12 +169,13 @@ SimpleTensor<T> scale(const SimpleTensor<T> &in, float scale_x, float scale_y, I
}
template SimpleTensor<uint8_t> scale(const SimpleTensor<uint8_t> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, uint8_t constant_border_value,
- bool ceil_policy_scale);
+ SamplingPolicy sampling_policy, bool ceil_policy_scale);
template SimpleTensor<int16_t> scale(const SimpleTensor<int16_t> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, int16_t constant_border_value,
- bool ceil_policy_scale);
-template SimpleTensor<half> scale(const SimpleTensor<half> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, half constant_border_value, bool ceil_policy_scale);
+ SamplingPolicy sampling_policy, bool ceil_policy_scale);
+template SimpleTensor<half> scale(const SimpleTensor<half> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, half constant_border_value,
+ SamplingPolicy sampling_policy, bool ceil_policy_scale);
template SimpleTensor<float> scale(const SimpleTensor<float> &src, float scale_x, float scale_y, InterpolationPolicy policy, BorderMode border_mode, float constant_border_value,
- bool ceil_policy_scale);
+ SamplingPolicy sampling_policy, bool ceil_policy_scale);
} // namespace reference
} // namespace validation
} // namespace test