diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/DepthConvertLayer.cpp | 25 | ||||
-rw-r--r-- | tests/validation/reference/DepthConvertLayer.h | 7 |
2 files changed, 27 insertions, 5 deletions
diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp index 7da0011fbb..57eeb7f6f3 100644 --- a/tests/validation/reference/DepthConvertLayer.cpp +++ b/tests/validation/reference/DepthConvertLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -63,14 +63,14 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con return result; } -template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&!std::is_same<T1, T2>::value, int >::type > +template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&(!std::is_same<T1, T2>::value &&!std::is_same<T2, bfloat16>::value), int >::type > SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift) { SimpleTensor<T2> result(src.shape(), dt_out); ARM_COMPUTE_ERROR_ON(shift != 0); ARM_COMPUTE_UNUSED(policy, shift); - if(!is_floating_point<T2>::value) + if(!std::is_same<T2, bfloat16>::value && !is_floating_point<T2>::value) { // Always saturate on floats for(int i = 0; i < src.num_elements(); ++i) @@ -89,6 +89,21 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con return result; } +template < typename T1, typename T2, typename std::enable_if < std::is_same<T1, bfloat16>::value || std::is_same<T2, bfloat16>::value, int >::type > +SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift) +{ + SimpleTensor<T2> result(src.shape(), dt_out); + ARM_COMPUTE_ERROR_ON(shift != 0); + ARM_COMPUTE_UNUSED(policy, shift); + + for(int i = 0; i < src.num_elements(); ++i) + { + result[i] = static_cast<T2>(src[i]); + } + + return result; +} + // U8 template SimpleTensor<int8_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); @@ -143,6 +158,9 @@ template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<int32_t> &src, template SimpleTensor<half> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<float> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); +// BFLOAT16 +template SimpleTensor<float> depth_convert(const SimpleTensor<bfloat16> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); + // F16 template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<int8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); @@ -160,6 +178,7 @@ template SimpleTensor<int16_t> depth_convert(const SimpleTensor<float> &src, Dat template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<int32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<half> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); +template SimpleTensor<bfloat16> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); } // namespace reference } // namespace validation diff --git a/tests/validation/reference/DepthConvertLayer.h b/tests/validation/reference/DepthConvertLayer.h index f9f849b3f7..9513d07c34 100644 --- a/tests/validation/reference/DepthConvertLayer.h +++ b/tests/validation/reference/DepthConvertLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -38,7 +38,10 @@ namespace reference template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&!std::is_same<T1, T2>::value, int >::type = 0 > SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); -template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&!std::is_same<T1, T2>::value, int >::type = 0 > +template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&(!std::is_same<T1, T2>::value &&!std::is_same<T2, bfloat16>::value), int >::type = 0 > +SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); + +template < typename T1, typename T2, typename std::enable_if < std::is_same<T1, bfloat16>::value || std::is_same<T2, bfloat16>::value, int >::type = 0 > SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); } // namespace reference } // namespace validation |