aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorOperations.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r--tests/validation/TensorOperations.h29
1 files changed, 9 insertions, 20 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index fce257540b..569559352a 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -201,6 +201,7 @@ void vector_matrix_multiply(const int8_t *in, const int8_t *weights, const int8_
}
}
+// Return a tensor element at a specified coordinate with different border modes
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
T tensor_elem_at(const Tensor<T> &in, Coordinates &coord, BorderMode border_mode, T constant_border_value)
{
@@ -209,14 +210,10 @@ T tensor_elem_at(const Tensor<T> &in, Coordinates &coord, BorderMode border_mode
const int width = static_cast<int>(in.shape().x());
const int height = static_cast<int>(in.shape().y());
- // If on border
+ // If coordinates beyond range of tensor's width or height
if(x < 0 || y < 0 || x >= width || y >= height)
{
- if(border_mode == BorderMode::CONSTANT)
- {
- return constant_border_value;
- }
- else if(border_mode == BorderMode::REPLICATE)
+ if(border_mode == BorderMode::REPLICATE)
{
coord.set(0, std::max(0, std::min(x, width - 1)));
coord.set(1, std::max(0, std::min(y, height - 1)));
@@ -224,10 +221,7 @@ T tensor_elem_at(const Tensor<T> &in, Coordinates &coord, BorderMode border_mode
}
else
{
- // Return a random value if on border and border_mode == UNDEFINED
- std::mt19937 gen(user_config.seed.get());
- std::uniform_int_distribution<T> distribution(0, 255);
- return distribution(gen);
+ return constant_border_value;
}
}
else
@@ -257,8 +251,7 @@ void apply_2d_spatial_filter(Coordinates coord, const Tensor<T1> &in, Tensor<T3>
{
coord.set(0, i);
coord.set(1, j);
- double pixel_to_multiply = tensor_elem_at(in, coord, border_mode, constant_border_value);
- val += static_cast<double>(*filter_itr) * pixel_to_multiply;
+ val += static_cast<double>(*filter_itr) * tensor_elem_at(in, coord, border_mode, constant_border_value);
++filter_itr;
}
}
@@ -508,20 +501,16 @@ void bitwise_not(const Tensor<T> &in, Tensor<T> &out)
}
}
-// 3-by-3 box filter
+// Box3x3 filter
template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
-void box3x3(const Tensor<T> &in, Tensor<T> &out)
+void box3x3(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T constant_border_value)
{
const std::array<T, 9> filter{ { 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
- float scale = 1.f / static_cast<float>(filter.size());
- const ValidRegion valid_region = shape_to_valid_region_undefined_border(in.shape(), BorderSize(1));
+ float scale = 1.f / static_cast<float>(filter.size());
for(int element_idx = 0; element_idx < in.num_elements(); ++element_idx)
{
const Coordinates id = index2coord(in.shape(), element_idx);
- if(is_in_valid_region(valid_region, id))
- {
- apply_2d_spatial_filter(id, in, out, TensorShape(3U, 3U), filter.data(), scale, BorderMode::UNDEFINED);
- }
+ apply_2d_spatial_filter(id, in, out, TensorShape(3U, 3U), filter.data(), scale, border_mode, constant_border_value);
}
}