aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2020-02-26 09:58:13 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-03-05 15:15:15 +0000
commite8291acc1d9e89c9274d31f0d5bb4779eb95588c (patch)
tree5a0fef36d6daabe387174e55b60de54557c75291
parentaa85cdf22802cb892d7fa422ca505a43d84adb38 (diff)
downloadComputeLibrary-e8291acc1d9e89c9274d31f0d5bb4779eb95588c.tar.gz
COMPMID-3152: Initial Bfloat16 support
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
-rw-r--r--SConstruct2
-rw-r--r--arm_compute/core/CPP/Validate.h45
-rw-r--r--arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h25
-rw-r--r--arm_compute/core/NEON/wrapper/intrinsics/cvt.h19
-rw-r--r--arm_compute/core/PixelValue.h21
-rw-r--r--arm_compute/core/Types.h3
-rw-r--r--arm_compute/core/Utils.h27
-rw-r--r--arm_compute/runtime/NEON/functions/NEDepthConvertLayer.h23
-rw-r--r--src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp93
-rw-r--r--src/core/Utils.cpp13
-rw-r--r--src/runtime/CPUUtils.cpp6
-rw-r--r--support/Bfloat16.h140
-rw-r--r--support/ToolchainSupport.h14
-rw-r--r--tests/AssetsLibrary.h23
-rw-r--r--tests/Utils.h5
-rw-r--r--tests/validation/Helpers.cpp14
-rw-r--r--tests/validation/NEON/DepthConvertLayer.cpp40
-rw-r--r--tests/validation/reference/DepthConvertLayer.cpp25
-rw-r--r--tests/validation/reference/DepthConvertLayer.h7
-rw-r--r--utils/TypePrinter.h3
20 files changed, 501 insertions, 47 deletions
diff --git a/SConstruct b/SConstruct
index 83b1409c4c..0076a365e8 100644
--- a/SConstruct
+++ b/SConstruct
@@ -205,7 +205,7 @@ elif 'v8' in env['arch']:
env.Append(CXXFLAGS = ['-march=armv8-a'])
if 'v8.6-a' in env['arch']:
- env.Append(CXXFLAGS = ['-DV8P6'])
+ env.Append(CPPDEFINES = ['V8P6', 'ARM_COMPUTE_FORCE_BF16'])
elif 'x86' in env['arch']:
if env['estate'] == '32':
diff --git a/arm_compute/core/CPP/Validate.h b/arm_compute/core/CPP/Validate.h
index f195a31d00..dfee9de86e 100644
--- a/arm_compute/core/CPP/Validate.h
+++ b/arm_compute/core/CPP/Validate.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -48,6 +48,26 @@ inline Status error_on_unsupported_cpu_fp16(const char *function, const char *fi
return Status {};
}
+/** Return an error if the data type of the passed tensor info is BFLOAT16 and BFLOAT16 support is not compiled in.
+ *
+ * @param[in] function Function in which the error occurred.
+ * @param[in] file Name of the file where the error occurred.
+ * @param[in] line Line on which the error occurred.
+ * @param[in] tensor_info Tensor info to validate.
+ *
+ * @return Status
+ */
+inline Status error_on_unsupported_cpu_bf16(const char *function, const char *file, const int line,
+ const ITensorInfo *tensor_info)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor_info == nullptr, function, file, line);
+#if !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16))
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC_MSG(tensor_info->data_type() == DataType::BFLOAT16,
+ function, file, line, "This CPU architecture does not support BFloat16 data type, you need v8.6 or above");
+#endif /* !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)) */
+ return Status {};
+}
+
/** Return an error if the data type of the passed tensor is FP16 and FP16 support is not compiled in.
*
* @param[in] function Function in which the error occurred.
@@ -65,10 +85,33 @@ inline Status error_on_unsupported_cpu_fp16(const char *function, const char *fi
return Status{};
}
+/** Return an error if the data type of the passed tensor is BFLOAT16 and BFLOAT16 support is not compiled in.
+ *
+ * @param[in] function Function in which the error occurred.
+ * @param[in] file Name of the file where the error occurred.
+ * @param[in] line Line on which the error occurred.
+ * @param[in] tensor Tensor to validate.
+ *
+ * @return Status
+ */
+inline Status error_on_unsupported_cpu_bf16(const char *function, const char *file, const int line,
+ const ITensor *tensor)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_LOC(tensor == nullptr, function, file, line);
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_unsupported_cpu_bf16(function, file, line, tensor->info()));
+ return Status{};
+}
+
#define ARM_COMPUTE_ERROR_ON_CPU_F16_UNSUPPORTED(tensor) \
ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_unsupported_cpu_fp16(__func__, __FILE__, __LINE__, tensor))
#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor) \
ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_unsupported_cpu_fp16(__func__, __FILE__, __LINE__, tensor))
+
+#define ARM_COMPUTE_ERROR_ON_CPU_BF16_UNSUPPORTED(tensor) \
+ ARM_COMPUTE_ERROR_THROW_ON(::arm_compute::error_on_unsupported_cpu_bf16(__func__, __FILE__, __LINE__, tensor))
+
+#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(tensor) \
+ ARM_COMPUTE_RETURN_ON_ERROR(::arm_compute::error_on_unsupported_cpu_bf16(__func__, __FILE__, __LINE__, tensor))
} // namespace arm_compute
#endif /* ARM_COMPUTE_CPP_VALIDATE_H */
diff --git a/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h b/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h
index df4102cb86..5cda3203ed 100644
--- a/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEDepthConvertLayerKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -55,24 +55,25 @@ public:
* Valid conversions Input -> Output :
*
* - QASYMM8_SIGNED -> S16, S32, F32, F16
- * - QASYMM8 -> U16, S16, S32, F32, F16
- * - U8 -> U16, S16, S32, F32, F16
- * - U16 -> U8, U32
- * - S16 -> QASYMM8_SIGNED, U8, S32
- * - F16 -> QASYMM8_SIGNED, QASYMM8, F32, S32, U8
- * - S32 -> QASYMM8_SIGNED, QASYMM8, F16, F32, U8
- * - F32 -> QASYMM8_SIGNED, QASYMM8, F16, S32, U8
+ * - QASYMM8 -> U16, S16, S32, F32, F16
+ * - U8 -> U16, S16, S32, F32, F16
+ * - U16 -> U8, U32
+ * - S16 -> QASYMM8_SIGNED, U8, S32
+ * - BFLOAT16 -> F32
+ * - F16 -> QASYMM8_SIGNED, QASYMM8, F32, S32, U8
+ * - S32 -> QASYMM8_SIGNED, QASYMM8, F16, F32, U8
+ * - F32 -> QASYMM8_SIGNED, QASYMM8, BFLOAT16, F16, S32, U8
*
- * @param[in] input The input tensor to convert. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/F16/F32.
- * @param[out] output The output tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/F16/F32.
+ * @param[in] input The input tensor to convert. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/BFLOAT16/F16/F32.
+ * @param[out] output The output tensor. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
* @param[in] policy Conversion policy.
* @param[in] shift (Optional) Value for down/up conversions. Must be 0 <= shift < 8.
*/
void configure(const ITensor *input, ITensor *output, ConvertPolicy policy, uint32_t shift = 0);
/** Static function to check if given info will lead to a valid configuration of @ref NEDepthConvertLayerKernel
*
- * @param[in] input Source tensor info. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/F16/F32.
- * @param[in] output Destination tensor info. Data type supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/F16/F32.
+ * @param[in] input Source tensor info. Data types supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/BFLOAT16/F16/F32.
+ * @param[in] output Destination tensor info. Data type supported: QASYMM8_SIGNED/QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
* @param[in] policy Conversion policy
* @param[in] shift (Optional) Value for down/up conversions. Must be 0 <= shift < 8.
*
diff --git a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
index 1f22e09a11..5ea9a5dedd 100644
--- a/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
+++ b/arm_compute/core/NEON/wrapper/intrinsics/cvt.h
@@ -56,6 +56,25 @@ vcvt(const float32x4_t &a)
return vcvtq_s32_f32(a);
}
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+/** Convert 2x128-bit floating point vectors into 1x128-bit bfloat16 vector
+ *
+ * @param[in] inptr Pointer to the input memory to load values from
+ * @param[out] outptr    Pointer to the output memory to store values to
+ */
+inline void vcvt_bf16_f32(const float *inptr, uint16_t *outptr)
+{
+ __asm __volatile(
+ "ldp q0, q1, [%[inptr]]\n"
+ ".inst 0xea16800\n" // BFCVTN v0, v0
+ ".inst 0x4ea16820\n" // BFCVTN2 v0, v1
+ "str q0, [%[outptr]]\n"
+ : [inptr] "+r"(inptr)
+ : [outptr] "r"(outptr)
+ : "v0", "v1", "memory");
+}
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+
} // namespace wrapper
} // namespace arm_compute
#endif /* ARM_COMPUTE_WRAPPER_CVT_H */
diff --git a/arm_compute/core/PixelValue.h b/arm_compute/core/PixelValue.h
index 31bc55098a..337ccbc3f7 100644
--- a/arm_compute/core/PixelValue.h
+++ b/arm_compute/core/PixelValue.h
@@ -89,6 +89,9 @@ public:
case DataType::S64:
value.s64 = static_cast<int64_t>(v);
break;
+ case DataType::BFLOAT16:
+ value.bf16 = static_cast<bfloat16>(v);
+ break;
case DataType::F16:
value.f16 = static_cast<half>(v);
break;
@@ -174,6 +177,15 @@ public:
{
value.s64 = v;
}
+ /** Initialize the union with a BFLOAT16 pixel value
+ *
+ * @param[in] v BFLOAT16 value.
+ */
+ PixelValue(bfloat16 v)
+ : PixelValue()
+ {
+ value.bf16 = v;
+ }
/** Initialize the union with a F16 pixel value
*
* @param[in] v F16 value.
@@ -214,6 +226,7 @@ public:
double f64; /**< Single channel double */
float f32; /**< Single channel float 32 */
half f16; /**< Single channel F16 */
+ bfloat16 bf16; /**< Single channel brain floating-point number */
uint8_t u8; /**< Single channel U8 */
int8_t s8; /**< Single channel S8 */
uint16_t u16; /**< Single channel U16 */
@@ -285,6 +298,14 @@ public:
{
v = value.s64;
}
+ /** Interpret the pixel value as a BFLOAT16
+ *
+ * @param[out] v Returned value
+ */
+ void get(bfloat16 &v) const
+ {
+ v = value.bf16;
+ }
/** Interpret the pixel value as a F16
*
* @param[out] v Returned value
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index cf689d757c..b6409879bb 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -30,6 +30,7 @@
#include "arm_compute/core/Strides.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/utils/misc/Macros.h"
+#include "support/Bfloat16.h"
#include "support/Half.h"
#include <cmath>
@@ -58,6 +59,7 @@ enum class Format
U16, /**< 1 channel, 1 U16 per channel */
S32, /**< 1 channel, 1 S32 per channel */
U32, /**< 1 channel, 1 U32 per channel */
+ BFLOAT16, /**< 16-bit brain floating-point number */
F16, /**< 1 channel, 1 F16 per channel */
F32, /**< 1 channel, 1 F32 per channel */
UV88, /**< 2 channel, 1 U8 per channel */
@@ -89,6 +91,7 @@ enum class DataType
S32, /**< signed 32-bit number */
U64, /**< unsigned 64-bit number */
S64, /**< signed 64-bit number */
+ BFLOAT16, /**< 16-bit brain floating-point number */
F16, /**< 16-bit floating-point number */
F32, /**< 32-bit floating-point number */
F64, /**< 64-bit floating-point number */
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 4a3b01d21f..8577046af0 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -114,6 +114,7 @@ inline size_t data_size_from_type(DataType data_type)
case DataType::S16:
case DataType::QSYMM16:
case DataType::QASYMM16:
+ case DataType::BFLOAT16:
case DataType::F16:
return 2;
case DataType::F32:
@@ -146,6 +147,7 @@ inline size_t pixel_size_from_format(Format format)
return 1;
case Format::U16:
case Format::S16:
+ case Format::BFLOAT16:
case Format::F16:
case Format::UV88:
case Format::YUYV422:
@@ -191,6 +193,7 @@ inline size_t element_size_from_data_type(DataType dt)
case DataType::S16:
case DataType::QSYMM16:
case DataType::QASYMM16:
+ case DataType::BFLOAT16:
case DataType::F16:
return 2;
case DataType::U32:
@@ -228,6 +231,8 @@ inline DataType data_type_from_format(Format format)
return DataType::U32;
case Format::S32:
return DataType::S32;
+ case Format::BFLOAT16:
+ return DataType::BFLOAT16;
case Format::F16:
return DataType::F16;
case Format::F32:
@@ -260,6 +265,7 @@ inline int plane_idx_from_channel(Format format, Channel channel)
case Format::S16:
case Format::U32:
case Format::S32:
+ case Format::BFLOAT16:
case Format::F16:
case Format::F32:
case Format::UV88:
@@ -447,6 +453,7 @@ inline size_t num_planes_from_format(Format format)
case Format::U16:
case Format::S32:
case Format::U32:
+ case Format::BFLOAT16:
case Format::F16:
case Format::F32:
case Format::RGB888:
@@ -481,6 +488,7 @@ inline size_t num_channels_from_format(Format format)
case Format::S16:
case Format::U32:
case Format::S32:
+ case Format::BFLOAT16:
case Format::F16:
case Format::F32:
return 1;
@@ -531,6 +539,7 @@ inline DataType get_promoted_data_type(DataType dt)
case DataType::QSYMM8_PER_CHANNEL:
case DataType::QSYMM16:
case DataType::QASYMM16:
+ case DataType::BFLOAT16:
case DataType::F16:
case DataType::U32:
case DataType::S32:
@@ -596,6 +605,12 @@ inline std::tuple<PixelValue, PixelValue> get_min_max(DataType dt)
max = PixelValue(std::numeric_limits<int32_t>::max());
break;
}
+ case DataType::BFLOAT16:
+ {
+ min = PixelValue(bfloat16::lowest());
+ max = PixelValue(bfloat16::max());
+ break;
+ }
case DataType::F16:
{
min = PixelValue(std::numeric_limits<half>::lowest());
@@ -1284,6 +1299,8 @@ bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = Quantization
const auto val_s32 = static_cast<int32_t>(val);
return ((val_s32 == val) && val_s32 >= std::numeric_limits<int32_t>::lowest() && val_s32 <= std::numeric_limits<int32_t>::max());
}
+ case DataType::BFLOAT16:
+ return (val >= bfloat16::lowest() && val <= bfloat16::max());
case DataType::F16:
return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
case DataType::F32:
@@ -1323,6 +1340,11 @@ void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int
// We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
s << std::right << static_cast<T>(ptr[i]) << element_delim;
}
+ else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
+ {
+ // We print through float because std::is_floating_point<bfloat16> returns false, which would otherwise make print_type an int.
+ s << std::right << float(ptr[i]) << element_delim;
+ }
else
{
s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
@@ -1357,6 +1379,11 @@ int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, u
// We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
ss << static_cast<T>(ptr[i]);
}
+ else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
+ {
+ // We print through float because std::is_floating_point<bfloat16> returns false, which would otherwise make print_type an int.
+ ss << float(ptr[i]);
+ }
else
{
ss << static_cast<print_type>(ptr[i]);
diff --git a/arm_compute/runtime/NEON/functions/NEDepthConvertLayer.h b/arm_compute/runtime/NEON/functions/NEDepthConvertLayer.h
index 43a256ebe2..b784480887 100644
--- a/arm_compute/runtime/NEON/functions/NEDepthConvertLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEDepthConvertLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,23 +47,24 @@ public:
*
* Valid conversions Input -> Output :
*
- * - QASYMM8 -> F16, F32
- * - U8 -> U16, S16, S32
- * - U16 -> U8, U32
- * - S16 -> U8, S32
- * - F16 -> QASYMM8, F32
- * - F32 -> QASYMM8, F16
+ * - QASYMM8 -> F16, F32
+ * - U8 -> U16, S16, S32
+ * - U16 -> U8, U32
+ * - S16 -> U8, S32
+ * - BFLOAT16 -> F32
+ * - F16 -> QASYMM8, F32
+ * - F32 -> QASYMM8, F16, BFLOAT16
*
- * @param[in] input The input tensor to convert. Data types supported: QASYMM8/U8/U16/S16/F16/F32.
- * @param[out] output The output tensor. Data types supported: QASYMM8/U8/U16/S16/U32/S32/F16/F32.
+ * @param[in] input The input tensor to convert. Data types supported: QASYMM8/U8/U16/S16/BFLOAT16/F16/F32.
+ * @param[out] output The output tensor. Data types supported: QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
* @param[in] policy Conversion policy.
* @param[in] shift (Optional) Value for down/up conversions. Must be 0 <= shift < 8.
*/
void configure(const ITensor *input, ITensor *output, ConvertPolicy policy, uint32_t shift = 0);
/** Static function to check if given info will lead to a valid configuration of @ref NEDepthConvertLayer
*
- * @param[in] input Source tensor info. Data types supported: QASYMM8/U8/U16/S16/F16/F32.
- * @param[in] output Destination tensor info. Data type supported: QASYMM8/U8/U16/S16/U32/S32/F16/F32.
+ * @param[in] input Source tensor info. Data types supported: QASYMM8/U8/U16/S16/BFLOAT16/F16/F32.
+ * @param[in] output Destination tensor info. Data type supported: QASYMM8/U8/U16/S16/U32/S32/BFLOAT16/F16/F32.
* @param[in] policy Conversion policy.
* @param[in] shift (Optional) Value for down/up conversions. Must be 0 <= shift < 8.
*
diff --git a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
index f824f7ac58..79dc2cb585 100644
--- a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
@@ -33,7 +33,7 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/SaturateCast.h"
-#include <arm_neon.h>
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
using namespace arm_compute;
@@ -43,11 +43,16 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, C
{
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(output);
ARM_COMPUTE_UNUSED(policy);
ARM_COMPUTE_RETURN_ERROR_ON(input == output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::F16, DataType::F32, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16,
- DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+ DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::F32, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+ DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::U32, DataType::S32, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(shift >= 8);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QASYMM8_SIGNED && (output->data_type() != DataType::S16 && output->data_type() != DataType::S32
@@ -68,15 +73,18 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, C
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::U8 && output->data_type() != DataType::S32),
"Only data_types supported [in] S16 -> [out] U8, S32");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::BFLOAT16 && output->data_type() != DataType::F32,
+ "Only data_types supported [in] BFLOAT16 -> [out] F32");
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
&& output->data_type() != DataType::U8
&& output->data_type() != DataType::F32 && output->data_type() != DataType::S32),
"Only data_types supported [in] F16 -> [out] QASYMM8, F32, S32, U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
- && output->data_type() != DataType::F16
+ && output->data_type() != DataType::F16 && output->data_type() != DataType::BFLOAT16
&& output->data_type() != DataType::S32 && output->data_type() != DataType::U8),
- "Only data_types supported [in] F32 -> [out] QASYMM8, F16, S32, U8");
+ "Only data_types supported [in] F32 -> [out] QASYMM8, BFLOAT16, F16, S32, U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
&& output->data_type() != DataType::F16
@@ -786,6 +794,52 @@ void NEDepthConvertLayerKernel::run(const Window &window, const ThreadInfo &info
}
break;
}
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+ case DataType::BFLOAT16:
+ switch(_output->info()->data_type())
+ {
+ case DataType::F32:
+ {
+ /* Up-conversion BFLOAT16 -> F32 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<const bfloat16 *>(input.ptr());
+ const auto output_ptr = reinterpret_cast<float *>(output.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint16x8x2_t texels =
+ {
+ {
+ vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())),
+ vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8)
+ }
+ };
+
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()),
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16)));
+ }
+
+ for(; x < window_end_x; ++x)
+ {
+ *(output_ptr + x) = float(*(input_ptr + x));
+ }
+ },
+ input, output);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Output data type unsupported");
+ }
+ break;
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
switch(_output->info()->data_type())
@@ -980,6 +1034,33 @@ void NEDepthConvertLayerKernel::run(const Window &window, const ThreadInfo &info
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+ case DataType::BFLOAT16:
+ {
+ /* Down-conversion F32 -> BFLOAT16 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<const float *>(input.ptr());
+ const auto output_ptr = reinterpret_cast<bfloat16 *>(output.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()),
+ reinterpret_cast<uint16_t *>(output.ptr()));
+ wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()) + 8,
+ reinterpret_cast<uint16_t *>(output.ptr()) + 8);
+ }
+
+ for(; x < window_end_x; ++x)
+ {
+ *(output_ptr + x) = *(input_ptr + x);
+ }
+ },
+ input, output);
+ break;
+ }
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
case DataType::S32:
{
const float scale_s = 1.f / (1 << _shift);
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index bca5a31914..fb86d78cb7 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -510,12 +510,15 @@ void arm_compute::print_consecutive_elements(std::ostream &s, DataType dt, const
case DataType::S32:
print_consecutive_elements_impl<int32_t>(s, reinterpret_cast<const int32_t *>(ptr), n, stream_width, element_delim);
break;
- case DataType::F32:
- print_consecutive_elements_impl<float>(s, reinterpret_cast<const float *>(ptr), n, stream_width, element_delim);
+ case DataType::BFLOAT16:
+ print_consecutive_elements_impl<bfloat16>(s, reinterpret_cast<const bfloat16 *>(ptr), n, stream_width, element_delim);
break;
case DataType::F16:
print_consecutive_elements_impl<half>(s, reinterpret_cast<const half *>(ptr), n, stream_width, element_delim);
break;
+ case DataType::F32:
+ print_consecutive_elements_impl<float>(s, reinterpret_cast<const float *>(ptr), n, stream_width, element_delim);
+ break;
default:
ARM_COMPUTE_ERROR("Undefined element size for given data type");
}
@@ -542,10 +545,12 @@ int arm_compute::max_consecutive_elements_display_width(std::ostream &s, DataTyp
return max_consecutive_elements_display_width_impl<uint32_t>(s, reinterpret_cast<const uint32_t *>(ptr), n);
case DataType::S32:
return max_consecutive_elements_display_width_impl<int32_t>(s, reinterpret_cast<const int32_t *>(ptr), n);
- case DataType::F32:
- return max_consecutive_elements_display_width_impl<float>(s, reinterpret_cast<const float *>(ptr), n);
+ case DataType::BFLOAT16:
+ return max_consecutive_elements_display_width_impl<bfloat16>(s, reinterpret_cast<const bfloat16 *>(ptr), n);
case DataType::F16:
return max_consecutive_elements_display_width_impl<half>(s, reinterpret_cast<const half *>(ptr), n);
+ case DataType::F32:
+ return max_consecutive_elements_display_width_impl<float>(s, reinterpret_cast<const float *>(ptr), n);
default:
ARM_COMPUTE_ERROR("Undefined element size for given data type");
}
diff --git a/src/runtime/CPUUtils.cpp b/src/runtime/CPUUtils.cpp
index 5860720d3b..e632787e86 100644
--- a/src/runtime/CPUUtils.cpp
+++ b/src/runtime/CPUUtils.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -402,7 +402,7 @@ unsigned int get_threads_hint()
{
unsigned int num_threads_hint = 1;
-#ifndef BARE_METAL
+#if !defined(BARE_METAL)
std::map<std::string, unsigned int> cpu_part_occurrence_map;
// CPU part regex
@@ -447,7 +447,7 @@ unsigned int get_threads_hint()
// Set thread hint
num_threads_hint = cpu_part_occurrence_map.empty() ? std::thread::hardware_concurrency() : min_common_cores->second;
-#endif /* BARE_METAL */
+#endif /* !defined(BARE_METAL) */
return num_threads_hint;
}
diff --git a/support/Bfloat16.h b/support/Bfloat16.h
new file mode 100644
index 0000000000..d897e42643
--- /dev/null
+++ b/support/Bfloat16.h
@@ -0,0 +1,140 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_BFLOAT16_H
+#define ARM_COMPUTE_BFLOAT16_H
+
+#include <cstdint>
+
+namespace arm_compute
+{
+namespace
+{
+/** Convert float to bfloat16
+ *
+ * @param[in] v Floating-point value to convert to bfloat
+ *
+ * @return Converted value
+ */
+inline uint16_t float_to_bf16(const float v)
+{
+ const uint32_t *fromptr = reinterpret_cast<const uint32_t *>(&v);
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+ uint16_t res;
+
+ __asm __volatile(
+ "ldr s0, [%[fromptr]]\n"
+ ".inst 0x1e634000\n" // BFCVT h0, s0
+ "str h0, [%[toptr]]\n"
+ :
+ : [fromptr] "r"(fromptr), [toptr] "r"(&res)
+ : "v0", "memory");
+#else /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+ uint16_t res = (*fromptr >> 16);
+ const uint16_t error = (*fromptr & 0x0000ffff);
+ uint16_t bf_l = res & 0x0001;
+ if((error > 0x8000) || ((error == 0x8000) && (bf_l != 0)))
+ {
+ res += 1;
+ }
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+ return res;
+}
+
+/** Convert bfloat16 to float
+ *
+ * @param[in] v Bfloat16 value to convert to float
+ *
+ * @return Converted value
+ */
+inline float bf16_to_float(const uint16_t &v)
+{
+ const uint32_t lv = (v << 16);
+ const float *fp = reinterpret_cast<const float *>(&lv);
+
+ return *fp;
+}
+}
+
+/** Brain floating point representation class */
+class bfloat16
+{
+public:
+ /** Default Constructor */
+ bfloat16()
+ : value(0)
+ {
+ }
+ /** Constructor
+ *
+ * @param[in] v Floating-point value
+ */
+ explicit bfloat16(float v)
+ : value(float_to_bf16(v))
+ {
+ }
+ /** Assignment operator
+ *
+ * @param[in] v Floating point value to assign
+ *
+ * @return The updated object
+ */
+ bfloat16 &operator=(float v)
+ {
+ value = float_to_bf16(v);
+ return *this;
+ }
+ /** Floating point conversion operator
+ *
+ * @return Floating point representation of the value
+ */
+ operator float() const
+ {
+ return bf16_to_float(value);
+ }
+ /** Lowest representative value
+ *
+ * @return Returns the lowest finite value representable by bfloat16
+ */
+ static bfloat16 lowest()
+ {
+ bfloat16 val;
+ val.value = 0xFF7F;
+ return val;
+ }
+ /** Largest representative value
+ *
+ * @return Returns the largest finite value representable by bfloat16
+ */
+ static bfloat16 max()
+ {
+ bfloat16 val;
+ val.value = 0x7F7F;
+ return val;
+ }
+
+private:
+ uint16_t value;
+};
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_BFLOAT16_H */ \ No newline at end of file
diff --git a/support/ToolchainSupport.h b/support/ToolchainSupport.h
index f90d65c9d9..923a9cbfe0 100644
--- a/support/ToolchainSupport.h
+++ b/support/ToolchainSupport.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,6 +39,7 @@
#include <arm_neon.h>
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#include "support/Bfloat16.h"
#include "support/Half.h"
namespace arm_compute
@@ -428,6 +429,12 @@ inline __fp16 lowest<__fp16>()
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+template <>
+inline bfloat16 lowest<bfloat16>()
+{
+ return bfloat16::lowest();
+}
+
// std::isfinite
template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
inline bool isfinite(T value)
@@ -439,6 +446,11 @@ inline bool isfinite(half_float::half value)
{
return half_float::isfinite(value);
}
+
+inline bool isfinite(bfloat16 value)
+{
+ return std::isfinite(float(value));
+}
} // namespace cpp11
namespace cpp14
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index c4892748f4..e625c37505 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -712,6 +712,13 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
fill(tensor, distribution_s64, seed_offset);
break;
}
+ case DataType::BFLOAT16:
+ {
+ // It doesn't make sense to check [-inf, inf], so hard code it to a big number
+ std::uniform_real_distribution<float> distribution_bf16(-1000.f, 1000.f);
+ fill(tensor, distribution_bf16, seed_offset);
+ break;
+ }
case DataType::F16:
{
// It doesn't make sense to check [-inf, inf], so hard code it to a big number
@@ -810,6 +817,14 @@ void AssetsLibrary::fill_tensor_uniform_ranged(T
fill(tensor, distribution_s32, seed_offset);
break;
}
+ case DataType::BFLOAT16:
+ {
+ // It doesn't make sense to check [-inf, inf], so hard code it to a big number
+ const auto converted_pairs = detail::convert_range_pair<float>(excluded_range_pairs);
+ RangedUniformDistribution<float> distribution_bf16(-1000.f, 1000.f, converted_pairs);
+ fill(tensor, distribution_bf16, seed_offset);
+ break;
+ }
case DataType::F16:
{
// It doesn't make sense to check [-inf, inf], so hard code it to a big number
@@ -896,6 +911,12 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
fill(tensor, distribution_s64, seed_offset);
break;
}
+ case DataType::BFLOAT16:
+ {
+ std::uniform_real_distribution<float> distribution_bf16(low, high);
+ fill(tensor, distribution_bf16, seed_offset);
+ break;
+ }
case DataType::F16:
{
std::uniform_real_distribution<float> distribution_f16(low, high);
diff --git a/tests/Utils.h b/tests/Utils.h
index 154d265cf9..3dc317f528 100644
--- a/tests/Utils.h
+++ b/tests/Utils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -383,6 +383,9 @@ void store_value_with_data_type(void *ptr, T value, DataType data_type)
case DataType::S64:
*reinterpret_cast<int64_t *>(ptr) = value;
break;
+ case DataType::BFLOAT16:
+ *reinterpret_cast<bfloat16 *>(ptr) = bfloat16(value);
+ break;
case DataType::F16:
*reinterpret_cast<half *>(ptr) = value;
break;
diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp
index afefee77be..4da9742c2a 100644
--- a/tests/validation/Helpers.cpp
+++ b/tests/validation/Helpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -212,6 +212,18 @@ SimpleTensor<float> convert_from_symmetric(const SimpleTensor<int16_t> &src)
return dst;
}
+SimpleTensor<float> convert_from_bfloat16(const SimpleTensor<int16_t> &src)
+{
+ SimpleTensor<float> dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() };
+ return dst;
+}
+
+SimpleTensor<int16_t> convert_to_bfloat(const SimpleTensor<float> &src)
+{
+ SimpleTensor<int16_t> dst{ src.shape(), DataType::BFLOAT16, 1, QuantizationInfo(), src.data_layout() };
+ return dst;
+}
+
template <typename T>
void matrix_multiply(const SimpleTensor<T> &a, const SimpleTensor<T> &b, SimpleTensor<T> &out)
{
diff --git a/tests/validation/NEON/DepthConvertLayer.cpp b/tests/validation/NEON/DepthConvertLayer.cpp
index b7de8fd9bc..163f539659 100644
--- a/tests/validation/NEON/DepthConvertLayer.cpp
+++ b/tests/validation/NEON/DepthConvertLayer.cpp
@@ -56,12 +56,14 @@ const auto DepthConvertLayerU16toU8Dataset = combine(framework::dataset::ma
const auto DepthConvertLayerU16toU32Dataset = combine(framework::dataset::make("DataType", DataType::U16), framework::dataset::make("DataType", DataType::U32));
const auto DepthConvertLayerS16toU8Dataset = combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8));
const auto DepthConvertLayerS16toS32Dataset = combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::S32));
+const auto DepthConvertLayerBF16toF32Dataset = combine(framework::dataset::make("DataType", DataType::BFLOAT16), framework::dataset::make("DataType", DataType::F32));
const auto DepthConvertLayerF16toU8Dataset = combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::U8));
const auto DepthConvertLayerF16toF32Dataset = combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F32));
const auto DepthConvertLayerF16toS32Dataset = combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::S32));
const auto DepthConvertLayerF32toF16Dataset = combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F16));
const auto DepthConvertLayerF32toS32Dataset = combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::S32));
const auto DepthConvertLayerF32toU8Dataset = combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::U8));
+const auto DepthConvertLayerF32toBF16Dataset = combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::BFLOAT16));
const auto DepthConvertLayerS32toF32Dataset = combine(framework::dataset::make("DataType", DataType::S32), framework::dataset::make("DataType", DataType::F32));
const auto DepthConvertLayerS32toQASYMM8Dataset = combine(framework::dataset::make("DataType", DataType::S32), framework::dataset::make("DataType", DataType::QASYMM8));
@@ -127,6 +129,8 @@ using NEDepthConvertLayerToU8Fixture = DepthConvertLayerValidationFixture<Tensor
template <typename T>
using NEDepthConvertLayerToU32Fixture = DepthConvertLayerValidationFixture<Tensor, Accessor, NEDepthConvertLayer, T, uint32_t>;
template <typename T>
+using NEDepthConvertLayerToBF16Fixture = DepthConvertLayerValidationFixture<Tensor, Accessor, NEDepthConvertLayer, T, bfloat16>;
+template <typename T>
using NEDepthConvertLayerToF16Fixture = DepthConvertLayerValidationFixture<Tensor, Accessor, NEDepthConvertLayer, T, half>;
template <typename T>
using NEDepthConvertLayerToF32Fixture = DepthConvertLayerValidationFixture<Tensor, Accessor, NEDepthConvertLayer, T, float>;
@@ -340,6 +344,42 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEDepthConvertLayerToS32Fixture<int16_t>, frame
}
TEST_SUITE_END() // S16_to_S32
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+TEST_SUITE(BFLOAT16_to_F32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthConvertLayerToF32Fixture<bfloat16>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), DepthConvertLayerBF16toF32Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
+ DepthConvertLayerZeroShiftDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEDepthConvertLayerToF32Fixture<bfloat16>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), DepthConvertLayerBF16toF32Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
+ DepthConvertLayerZeroShiftDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // BFLOAT16_to_F32
+
+TEST_SUITE(F32_to_BFLOAT16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthConvertLayerToBF16Fixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), DepthConvertLayerF32toBF16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
+ DepthConvertLayerZeroShiftDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEDepthConvertLayerToBF16Fixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), DepthConvertLayerF32toBF16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
+ DepthConvertLayerZeroShiftDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // F32_to_BFLOAT16
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(F16_to_QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthConvertLayerToQASYMM8Fixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp
index 7da0011fbb..57eeb7f6f3 100644
--- a/tests/validation/reference/DepthConvertLayer.cpp
+++ b/tests/validation/reference/DepthConvertLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -63,14 +63,14 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con
return result;
}
-template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&!std::is_same<T1, T2>::value, int >::type >
+template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&(!std::is_same<T1, T2>::value &&!std::is_same<T2, bfloat16>::value), int >::type >
SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift)
{
SimpleTensor<T2> result(src.shape(), dt_out);
ARM_COMPUTE_ERROR_ON(shift != 0);
ARM_COMPUTE_UNUSED(policy, shift);
- if(!is_floating_point<T2>::value)
+ if(!std::is_same<T2, bfloat16>::value && !is_floating_point<T2>::value)
{
// Always saturate on floats
for(int i = 0; i < src.num_elements(); ++i)
@@ -89,6 +89,21 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con
return result;
}
+template < typename T1, typename T2, typename std::enable_if < std::is_same<T1, bfloat16>::value || std::is_same<T2, bfloat16>::value, int >::type >
+SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift)
+{
+ SimpleTensor<T2> result(src.shape(), dt_out);
+ ARM_COMPUTE_ERROR_ON(shift != 0);
+ ARM_COMPUTE_UNUSED(policy, shift);
+
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ result[i] = static_cast<T2>(src[i]);
+ }
+
+ return result;
+}
+
// U8
template SimpleTensor<int8_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<uint16_t> depth_convert(const SimpleTensor<uint8_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
@@ -143,6 +158,9 @@ template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<int32_t> &src,
template SimpleTensor<half> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<float> depth_convert(const SimpleTensor<int32_t> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+// BFLOAT16
+template SimpleTensor<float> depth_convert(const SimpleTensor<bfloat16> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
// F16
template SimpleTensor<uint8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<int8_t> depth_convert(const SimpleTensor<half> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
@@ -160,6 +178,7 @@ template SimpleTensor<int16_t> depth_convert(const SimpleTensor<float> &src, Dat
template SimpleTensor<uint32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<int32_t> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
template SimpleTensor<half> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+template SimpleTensor<bfloat16> depth_convert(const SimpleTensor<float> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/reference/DepthConvertLayer.h b/tests/validation/reference/DepthConvertLayer.h
index f9f849b3f7..9513d07c34 100644
--- a/tests/validation/reference/DepthConvertLayer.h
+++ b/tests/validation/reference/DepthConvertLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,7 +38,10 @@ namespace reference
template < typename T1, typename T2, typename std::enable_if < std::is_integral<T1>::value &&!std::is_same<T1, T2>::value, int >::type = 0 >
SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
-template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&!std::is_same<T1, T2>::value, int >::type = 0 >
+template < typename T1, typename T2, typename std::enable_if < is_floating_point<T1>::value &&(!std::is_same<T1, T2>::value &&!std::is_same<T2, bfloat16>::value), int >::type = 0 >
+SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
+
+template < typename T1, typename T2, typename std::enable_if < std::is_same<T1, bfloat16>::value || std::is_same<T2, bfloat16>::value, int >::type = 0 >
SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, ConvertPolicy policy, uint32_t shift);
} // namespace reference
} // namespace validation
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 50eb4753d1..79ec367a52 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -665,6 +665,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const DataType &data_type)
case DataType::S64:
os << "S64";
break;
+ case DataType::BFLOAT16:
+ os << "BFLOAT16";
+ break;
case DataType::F16:
os << "F16";
break;