diff options
Diffstat (limited to 'tests/validation/TensorVisitors.h')
-rw-r--r-- | tests/validation/TensorVisitors.h | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tests/validation/TensorVisitors.h b/tests/validation/TensorVisitors.h index c58b9a69c0..168e212a65 100644 --- a/tests/validation/TensorVisitors.h +++ b/tests/validation/TensorVisitors.h @@ -45,6 +45,28 @@ namespace validation { namespace tensor_visitors { +// Min max location visitor +struct min_max_location_visitor : public boost::static_visitor<> +{ +public: + explicit min_max_location_visitor(int32_t &min, int32_t &max, Coordinates2DArray &min_loc, Coordinates2DArray &max_loc, uint32_t &min_count, uint32_t &max_count) + : _min(min), _max(max), _min_loc(min_loc), _max_loc(max_loc), _min_count(min_count), _max_count(max_count) + { + } + template <typename T1> + void operator()(const Tensor<T1> &in) const + { + tensor_operations::min_max_location(in, _min, _max, _min_loc, _max_loc, _min_count, _max_count); + } + +private: + int32_t &_min; + int32_t &_max; + Coordinates2DArray &_min_loc; + Coordinates2DArray &_max_loc; + uint32_t &_min_count; + uint32_t &_max_count; +}; // Absolute Difference visitor struct absolute_difference_visitor : public boost::static_visitor<> { |