diff options
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r-- | tests/validation/TensorOperations.h | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h index 4905e05732..0430d59d33 100644 --- a/tests/validation/TensorOperations.h +++ b/tests/validation/TensorOperations.h @@ -619,6 +619,25 @@ void gaussian3x3(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T } } +// Gaussian5x5 filter +template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type> +void gaussian5x5(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T constant_border_value) +{ + const std::array<T, 25> filter{ { + 1, 4, 6, 4, 1, + 4, 16, 24, 16, 4, + 6, 24, 36, 24, 6, + 4, 16, 24, 16, 4, + 1, 4, 6, 4, 1 + } }; + const float scale = 1.f / 256.f; + for(int element_idx = 0; element_idx < in.num_elements(); ++element_idx) + { + const Coordinates id = index2coord(in.shape(), element_idx); + apply_2d_spatial_filter(id, in, out, TensorShape(5U, 5U), filter.data(), scale, border_mode, constant_border_value); + } +} + // Matrix multiplication for floating point type template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type * = nullptr> void gemm(const Tensor<T> &in1, const Tensor<T> &in2, const Tensor<T> &in3, Tensor<T> &out, float alpha, float beta) |