From e8291acc1d9e89c9274d31f0d5bb4779eb95588c Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 26 Feb 2020 09:58:13 +0000 Subject: COMPMID-3152: Initial Bfloat16 support Signed-off-by: Georgios Pinitas Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Michalis Spyrou --- tests/validation/reference/DepthConvertLayer.cpp | 25 +++++++++++++++++++++--- tests/validation/reference/DepthConvertLayer.h | 7 +++++-- 2 files changed, 27 insertions(+), 5 deletions(-) (limited to 'tests/validation/reference') diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp index 7da0011fbb..57eeb7f6f3 100644 --- a/tests/validation/reference/DepthConvertLayer.cpp +++ b/tests/validation/reference/DepthConvertLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -63,14 +63,14 @@ SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, Con return result; } -template < typename T1, typename T2, typename std::enable_if < is_floating_point::value &&!std::is_same::value, int >::type > +template < typename T1, typename T2, typename std::enable_if < is_floating_point::value &&(!std::is_same::value &&!std::is_same::value), int >::type > SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift) { SimpleTensor result(src.shape(), dt_out); ARM_COMPUTE_ERROR_ON(shift != 0); ARM_COMPUTE_UNUSED(policy, shift); - if(!is_floating_point::value) + if(!std::is_same::value && !is_floating_point::value) { // Always saturate on floats for(int i = 0; i < src.num_elements(); ++i) @@ -89,6 +89,21 @@ SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, Con return result; } +template < typename T1, typename T2, typename std::enable_if < std::is_same::value || std::is_same::value, int >::type > +SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift) +{ + SimpleTensor result(src.shape(), dt_out); + ARM_COMPUTE_ERROR_ON(shift != 0); + ARM_COMPUTE_UNUSED(policy, shift); + + for(int i = 0; i < src.num_elements(); ++i) + { + result[i] = static_cast(src[i]); + } + + return result; +} + // U8 template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); @@ -143,6 +158,9 @@ template SimpleTensor depth_convert(const SimpleTensor &src, template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); +// BFLOAT16 +template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); + // F16 template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); @@ -160,6 +178,7 @@ template SimpleTensor depth_convert(const SimpleTensor &src, Dat template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); +template SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); } // namespace reference } // namespace validation diff --git a/tests/validation/reference/DepthConvertLayer.h b/tests/validation/reference/DepthConvertLayer.h index f9f849b3f7..9513d07c34 100644 --- a/tests/validation/reference/DepthConvertLayer.h +++ b/tests/validation/reference/DepthConvertLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -38,7 +38,10 @@ namespace reference template < typename T1, typename T2, typename std::enable_if < std::is_integral::value &&!std::is_same::value, int >::type = 0 > SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); -template < typename T1, typename T2, typename std::enable_if < is_floating_point::value &&!std::is_same::value, int >::type = 0 > +template < typename T1, typename T2, typename std::enable_if < is_floating_point::value &&(!std::is_same::value &&!std::is_same::value), int >::type = 0 > +SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); + +template < typename T1, typename T2, typename std::enable_if < std::is_same::value || std::is_same::value, int >::type = 0 > SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, ConvertPolicy policy, uint32_t shift); } // namespace reference } // namespace validation -- cgit v1.2.1