diff options
Diffstat (limited to 'tests/validation/CPP/Utils.h')
-rw-r--r-- | tests/validation/CPP/Utils.h | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tests/validation/CPP/Utils.h b/tests/validation/CPP/Utils.h index 2d879c129b..91d1afe1d7 100644 --- a/tests/validation/CPP/Utils.h +++ b/tests/validation/CPP/Utils.h @@ -41,8 +41,31 @@ namespace test { namespace validation { +// Return a tensor element at a specified coordinate with different border modes template <typename T> -T tensor_elem_at(const SimpleTensor<T> &in, Coordinates coord, BorderMode border_mode, T constant_border_value); +T tensor_elem_at(const SimpleTensor<T> &src, Coordinates coord, BorderMode border_mode, T constant_border_value) +{ + const int x = coord.x(); + const int y = coord.y(); + const int width = src.shape().x(); + const int height = src.shape().y(); + + // If coordinates beyond range of tensor's width or height + if(x < 0 || y < 0 || x >= width || y >= height) + { + if(border_mode == BorderMode::REPLICATE) + { + coord.set(0, std::max(0, std::min(x, width - 1))); + coord.set(1, std::max(0, std::min(y, height - 1))); + } + else + { + return constant_border_value; + } + } + + return src[coord2index(src.shape(), coord)]; +} template <typename T> T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, T constant_border_value); |