diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-02-26 09:58:13 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-03-05 15:15:15 +0000 |
commit | e8291acc1d9e89c9274d31f0d5bb4779eb95588c (patch) | |
tree | 5a0fef36d6daabe387174e55b60de54557c75291 /tests/validation/reference | |
parent | aa85cdf22802cb892d7fa422ca505a43d84adb38 (diff) | |
download | ComputeLibrary-e8291acc1d9e89c9274d31f0d5bb4779eb95588c.tar.gz |
COMPMID-3152: Initial Bfloat16 support
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/DepthConvertLayer.cpp | 25 | ||||
-rw-r--r-- | tests/validation/reference/DepthConvertLayer.h | 7 |
2 files changed, 27 insertions, 5 deletions
diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp index 7da0011fbb..57eeb7f6f3 100644 --- a/tests/validation/reference/DepthConvertLayer.cpp +++ b/tests/validation/reference/DepthConvertLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -63,14 +63,14 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con return result; } -template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&!std::is_same<T1, T2>::value, int >::type > +template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&(!std::is_same<T1, T2>::value &&!std::is_same<T2, bfloat16>::value), int >::type > SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift) { SimpleTensor<T2> result(src.shape(), dt_out); ARM_COMPUTE_ERROR_ON(shift != 0); ARM_COMPUTE_UNUSED(policy, shift); - if(!is_floating_point<T2>::value) + if(!std::is_same<T2, bfloat16>::value && !is_floating_point<T2>::value) { // Always saturate on floats for(int i = 0; i < src.num_elements(); ++i) @@ -89,6 +89,21 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con return result; } +template < typename T1, typename T2, typename std::enable_if < std::is_same<T1, bfloat16>::value || std::is_same<T2, bfloat16>::value, int >::type > +SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift) +{ + SimpleTensor<T2> result(src.shape(), dt_out); + ARM_COMPUTE_ERROR_ON(shift != 0); + ARM_COMPUTE_UNUSED(policy, shift); + + for(int i = 0; i < src.num_elements(); ++i) + { + result[i] = static_cast<T2>(src[i]); + } + + return result; +} + // U8 template SimpleTensor<int8_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); @@ -143,6 +158,9 @@ template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<int32_t> &src, template SimpleTensor<half> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<float> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); +// BFLOAT16 +template SimpleTensor<float> depth_convert(const SimpleTensor<bfloat16> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); + // F16 template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<int8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); @@ -160,6 +178,7 @@ template SimpleTensor<int16_t> depth_convert(const SimpleTensor<float> &src, Dat template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<int32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor<half> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); +template SimpleTensor<bfloat16> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); } // namespace reference } // namespace validation diff --git a/tests/validation/reference/DepthConvertLayer.h b/tests/validation/reference/DepthConvertLayer.h index f9f849b3f7..9513d07c34 100644 --- a/tests/validation/reference/DepthConvertLayer.h +++ b/tests/validation/reference/DepthConvertLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -38,7 +38,10 @@ namespace reference template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&!std::is_same<T1, T2>::value, int >::type = 0 > SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); -template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&!std::is_same<T1, T2>::value, int >::type = 0 > +template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&(!std::is_same<T1, T2>::value &&!std::is_same<T2, bfloat16>::value), int >::type = 0 > +SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); + +template < typename T1, typename T2, typename std::enable_if < std::is_same<T1, bfloat16>::value || std::is_same<T2, bfloat16>::value, int >::type = 0 > SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); } // namespace reference } // namespace validation |