diff options
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r-- | tests/validation/TensorOperations.h | 90 |
1 files changed, 27 insertions, 63 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h index 0502f53186..e90635f0d4 100644 --- a/tests/validation/TensorOperations.h +++ b/tests/validation/TensorOperations.h @@ -518,94 +518,58 @@ void box3x3(const Tensor<T> &in, Tensor<T> &out, BorderMode border_mode, T const } // Depth conversion -template <typename T1, typename T2> +template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_floating_point<T2>::value, int >::type = 0 > void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { - ARM_COMPUTE_ERROR("The conversion is not supported"); -} - -template <> -void depth_convert<int8_t, float>(const Tensor<int8_t> &in, Tensor<float> &out, ConvertPolicy policy, uint32_t shift) -{ - const int8_t fixed_point_position = static_cast<int8_t>(in.fixed_point_position()); - for(int i = 0; i < in.num_elements(); ++i) - { - out[i] = static_cast<float>(in[i]) * (1.0f / (1 << fixed_point_position)); - } -} - -template <> -void depth_convert<float, int8_t>(const Tensor<float> &in, Tensor<int8_t> &out, ConvertPolicy policy, uint32_t shift) -{ - const int8_t fixed_point_position = static_cast<int8_t>(in.fixed_point_position()); - for(int i = 0; i < in.num_elements(); ++i) - { - float val = in[i] * (1 << fixed_point_position) + 0.5f; - out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<int8_t>(val) : static_cast<int8_t>(val)); - } -} + using namespace fixed_point_arithmetic; -template <> -void depth_convert<uint8_t, uint16_t>(const Tensor<uint8_t> &in, Tensor<uint16_t> &out, ConvertPolicy policy, uint32_t shift) -{ + const int fixed_point_position = in.fixed_point_position(); for(int i = 0; i < in.num_elements(); ++i) { - out[i] = static_cast<uint16_t>(in[i]) << shift; + out[i] = static_cast<float>(fixed_point<T1>(in[i], fixed_point_position, true)); } } -template <> -void depth_convert<uint8_t, int16_t>(const Tensor<uint8_t> &in, Tensor<int16_t> &out, ConvertPolicy policy, uint32_t shift) +template < typename T1, typename T2, typename std::enable_if < std::is_floating_point<T1>::value &&std::is_integral<T2>::value, int >::type = 0 > +void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { - for(int i = 0; i < in.num_elements(); ++i) - { - out[i] = static_cast<int16_t>(in[i]) << shift; - } -} + using namespace fixed_point_arithmetic; -template <> -void depth_convert<uint8_t, int32_t>(const Tensor<uint8_t> &in, Tensor<int32_t> &out, ConvertPolicy policy, uint32_t shift) -{ + const int fixed_point_position = out.fixed_point_position(); for(int i = 0; i < in.num_elements(); ++i) { - out[i] = static_cast<int32_t>(in[i]) << shift; + out[i] = fixed_point<T2>(in[i], fixed_point_position).raw(); } } -template <> -void depth_convert<uint16_t, uint8_t>(const Tensor<uint16_t> &in, Tensor<uint8_t> &out, ConvertPolicy policy, uint32_t shift) +template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&std::is_integral<T2>::value, int >::type = 0 > +void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { - for(int i = 0; i < in.num_elements(); ++i) + // Up-casting + if(std::numeric_limits<T1>::digits <= std::numeric_limits<T2>::digits) { - uint16_t val = in[i] >> shift; - out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<uint8_t>(val) : static_cast<uint8_t>(val)); + for(int i = 0; i < in.num_elements(); ++i) + { + out[i] = static_cast<T2>(in[i]) << shift; + } } -} - -template <> -void depth_convert<uint16_t, uint32_t>(const Tensor<uint16_t> &in, Tensor<uint32_t> &out, ConvertPolicy policy, uint32_t shift) -{ - for(int i = 0; i < in.num_elements(); ++i) + // Down-casting + else { - out[i] = static_cast<uint32_t>(in[i]) << shift; + for(int i = 0; i < in.num_elements(); ++i) + { + T1 val = in[i] >> shift; + out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(val) : static_cast<T2>(val)); + } } } -template <> -void depth_convert<int16_t, uint8_t>(const Tensor<int16_t> &in, Tensor<uint8_t> &out, ConvertPolicy policy, uint32_t shift) -{ - for(int i = 0; i < in.num_elements(); ++i) - { - int16_t val = in[i] >> shift; - out[i] = ((policy == ConvertPolicy::SATURATE) ? saturate_cast<uint8_t>(val) : static_cast<uint8_t>(val)); - } -} -template <> -void depth_convert<int16_t, int32_t>(const Tensor<int16_t> &in, Tensor<int32_t> &out, ConvertPolicy policy, uint32_t shift) +template < typename T1, typename T2, typename std::enable_if < std::is_floating_point<T1>::value &&std::is_floating_point<T2>::value, int >::type = 0 > +void depth_convert(const Tensor<T1> &in, Tensor<T2> &out, ConvertPolicy policy, uint32_t shift) { for(int i = 0; i < in.num_elements(); ++i) { - out[i] = static_cast<int32_t>(in[i]) << shift; + out[i] = static_cast<T2>(in[i]); } } |