diff options
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r-- | tests/validation/TensorOperations.h | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h index 9e6f5cf5d1..882c9e07e1 100644 --- a/tests/validation/TensorOperations.h +++ b/tests/validation/TensorOperations.h @@ -519,7 +519,7 @@ void box3x3(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T const } // Depth conversion -template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&is_floating_point<T2>::value, int >::type = 0 > +template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_floating_point<T2>::value, int >::type = 0 > void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { using namespace fixed_point_arithmetic; @@ -531,7 +531,7 @@ void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, } } -template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&std::is_integral<T2>::value, int >::type = 0 > +template < typename T1, typename T2, typename std::enable_if < std::is_floating_point<T1>::value &&std::is_integral<T2>::value, int >::type = 0 > void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { using namespace fixed_point_arithmetic; @@ -543,7 +543,7 @@ void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, } } -template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_integral<T2>::value, int >::type = 0 > +template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_integral<T2>::value &&!std::is_same<T1, T2>::value, int >::type = 0 > void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { // Up-casting @@ -565,6 +565,26 @@ void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, } } +template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_integral<T2>::value &&std::is_same<T1, T2>::value, int >::type = 0 > +void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) +{ + using namespace fixed_point_arithmetic; + bool is_in_place = (&in == &out); + + const int fixed_point_position_in = in.fixed_point_position(); + const int fixed_point_position_out = (is_in_place) ? static_cast<int>(shift) : out.fixed_point_position(); + + if(!is_in_place || (fixed_point_position_in != fixed_point_position_out)) + { + for(int i = 0; i < in.num_elements(); ++i) + { + auto x = fixed_point<T2>(in[i], fixed_point_position_in, true); + x.rescale(fixed_point_position_out); + out[i] = x.raw(); + } + } +} + template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&is_floating_point<T2>::value, int >::type = 0 > void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { |