aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CPP/Utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CPP/Utils.h')
-rw-r--r--tests/validation/CPP/Utils.h25
1 files changed, 24 insertions, 1 deletions
diff --git a/tests/validation/CPP/Utils.h b/tests/validation/CPP/Utils.h
index 2d879c129b..91d1afe1d7 100644
--- a/tests/validation/CPP/Utils.h
+++ b/tests/validation/CPP/Utils.h
@@ -41,8 +41,31 @@ namespace test
{
namespace validation
{
+// Return a tensor element at a specified coordinate with different border modes
template <typename T>
-T tensor_elem_at(const SimpleTensor<T> &in, Coordinates coord, BorderMode border_mode, T constant_border_value);
+T tensor_elem_at(const SimpleTensor<T> &src, Coordinates coord, BorderMode border_mode, T constant_border_value)
+{
+ const int x = coord.x();
+ const int y = coord.y();
+ const int width = src.shape().x();
+ const int height = src.shape().y();
+
+ // If coordinates beyond range of tensor's width or height
+ if(x < 0 || y < 0 || x >= width || y >= height)
+ {
+ if(border_mode == BorderMode::REPLICATE)
+ {
+ coord.set(0, std::max(0, std::min(x, width - 1)));
+ coord.set(1, std::max(0, std::min(y, height - 1)));
+ }
+ else
+ {
+ return constant_border_value;
+ }
+ }
+
+ return src[coord2index(src.shape(), coord)];
+}
template <typename T>
T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, T constant_border_value);