diff options
Diffstat (limited to 'tests/validation/CPP/Utils.h')
-rw-r--r-- | tests/validation/CPP/Utils.h | 31 |
1 files changed, 28 insertions, 3 deletions
diff --git a/tests/validation/CPP/Utils.h b/tests/validation/CPP/Utils.h index 557d85f204..2d879c129b 100644 --- a/tests/validation/CPP/Utils.h +++ b/tests/validation/CPP/Utils.h @@ -47,9 +47,34 @@ T tensor_elem_at(const SimpleTensor<T> &in, Coordinates coord, BorderMode border template <typename T> T bilinear_policy(const SimpleTensor<T> &in, Coordinates id, float xn, float yn, BorderMode border_mode, T constant_border_value); -template <typename T1, typename T2, typename T3> -void apply_2d_spatial_filter(Coordinates coord, const SimpleTensor<T1> &in, SimpleTensor<T3> &out, const TensorShape &filter_shape, const T2 *filter_itr, float scale, BorderMode border_mode, - T1 constant_border_value = 0); +/* Apply 2D spatial filter on a single element of @p in at coordinates @p coord + * + * - filter sizes have to be odd number + * - Row major order of filter assumed + * - TO_ZERO rounding policy assumed + * - SATURATE convert policy assumed + */ +template <typename T, typename U, typename V> +void apply_2d_spatial_filter(Coordinates coord, const SimpleTensor<T> &src, SimpleTensor<U> &dst, const TensorShape &filter_shape, const V *filter_itr, double scale, BorderMode border_mode, + T constant_border_value = T(0)) +{ + double val = 0.; + const int x = coord.x(); + const int y = coord.y(); + for(int j = y - static_cast<int>(filter_shape[1] / 2); j <= y + static_cast<int>(filter_shape[1] / 2); ++j) + { + for(int i = x - static_cast<int>(filter_shape[0] / 2); i <= x + static_cast<int>(filter_shape[0] / 2); ++i) + { + coord.set(0, i); + coord.set(1, j); + val += static_cast<double>(*filter_itr) * tensor_elem_at(src, coord, border_mode, constant_border_value); + ++filter_itr; + } + } + coord.set(0, x); + coord.set(1, y); + dst[coord2index(src.shape(), coord)] = saturate_cast<U>(support::cpp11::trunc(val * scale)); +} RawTensor transpose(const RawTensor &src, int chunk_width = 1); } // namespace validation |