diff options
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r-- | tests/validation/TensorOperations.h | 29 |
1 files changed, 9 insertions, 20 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h index fce257540b..569559352a 100644 --- a/tests/validation/TensorOperations.h +++ b/tests/validation/TensorOperations.h @@ -201,6 +201,7 @@ void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_ } } +// Return a tensor element at a specified coordinate with different border modes template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0> T tensor_elem_at(const Tensor<T> &in, Coordinates &coord, BorderMode border_mode, T constant_border_value) { @@ -209,14 +210,10 @@ T tensor_elem_at(const Tensor<T> &in, Coordinates &coord, BorderMode border_mode const int width = static_cast<int>(in.shape().x()); const int height = static_cast<int>(in.shape().y()); - // If on border + // If coordinates beyond range of tensor's width or height if(x < 0 || y < 0 || x >= width || y >= height) { - if(border_mode == BorderMode::CONSTANT) - { - return constant_border_value; - } - else if(border_mode == BorderMode::REPLICATE) + if(border_mode == BorderMode::REPLICATE) { coord.set(0, std::max(0, std::min(x, width - 1))); coord.set(1, std::max(0, std::min(y, height - 1))); @@ -224,10 +221,7 @@ T tensor_elem_at(const Tensor<T> &in, Coordinates &coord, BorderMode border_mode } else { - // Return a random value if on border and border_mode == UNDEFINED - std::mt19937 gen(user_config.seed.get()); - std::uniform_int_distribution<T> distribution(0, 255); - return distribution(gen); + return constant_border_value; } } else @@ -257,8 +251,7 @@ void apply_2d_spatial_filter(Coordinates coord, const Tensor<T1> &in, Tensor<T3> { coord.set(0, i); coord.set(1, j); - double pixel_to_multiply = tensor_elem_at(in, coord, border_mode, constant_border_value); - val += static_cast<double>(*filter_itr) * pixel_to_multiply; + val += static_cast<double>(*filter_itr) * tensor_elem_at(in, coord, border_mode, constant_border_value); ++filter_itr; } } @@ -508,20 +501,16 @@ void bitwise_not(const Tensor<T> &in, Tensor<T> &out) } } -// 3-by-3 box filter +// Box3x3 filter template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type> -void box3x3(const Tensor<T> &in, Tensor<T> &out) +void box3x3(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T constant_border_value) { const std::array<T, 9> filter{ { 1, 1, 1, 1, 1, 1, 1, 1, 1 } }; - float scale = 1.f / static_cast<float>(filter.size()); - const ValidRegion valid_region = shape_to_valid_region_undefined_border(in.shape(), BorderSize(1)); + float scale = 1.f / static_cast<float>(filter.size()); for(int element_idx = 0; element_idx < in.num_elements(); ++element_idx) { const Coordinates id = index2coord(in.shape(), element_idx); - if(is_in_valid_region(valid_region, id)) - { - apply_2d_spatial_filter(id, in, out, TensorShape(3U, 3U), filter.data(), scale, BorderMode::UNDEFINED); - } + apply_2d_spatial_filter(id, in, out, TensorShape(3U, 3U), filter.data(), scale, border_mode, constant_border_value); } } |