aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/Utils.h
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2020-02-26 09:58:13 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-03-05 15:15:15 +0000
commite8291acc1d9e89c9274d31f0d5bb4779eb95588c (patch)
tree5a0fef36d6daabe387174e55b60de54557c75291 /arm_compute/core/Utils.h
parentaa85cdf22802cb892d7fa422ca505a43d84adb38 (diff)
downloadComputeLibrary-e8291acc1d9e89c9274d31f0d5bb4779eb95588c.tar.gz
COMPMID-3152: Initial Bfloat16 support
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'arm_compute/core/Utils.h')
-rw-r--r--arm_compute/core/Utils.h27
1 files changed, 27 insertions, 0 deletions
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 4a3b01d21f..8577046af0 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -114,6 +114,7 @@ inline size_t data_size_from_type(DataType data_type)
case DataType::S16:
case DataType::QSYMM16:
case DataType::QASYMM16:
+ case DataType::BFLOAT16:
case DataType::F16:
return 2;
case DataType::F32:
@@ -146,6 +147,7 @@ inline size_t pixel_size_from_format(Format format)
return 1;
case Format::U16:
case Format::S16:
+ case Format::BFLOAT16:
case Format::F16:
case Format::UV88:
case Format::YUYV422:
@@ -191,6 +193,7 @@ inline size_t element_size_from_data_type(DataType dt)
case DataType::S16:
case DataType::QSYMM16:
case DataType::QASYMM16:
+ case DataType::BFLOAT16:
case DataType::F16:
return 2;
case DataType::U32:
@@ -228,6 +231,8 @@ inline DataType data_type_from_format(Format format)
return DataType::U32;
case Format::S32:
return DataType::S32;
+ case Format::BFLOAT16:
+ return DataType::BFLOAT16;
case Format::F16:
return DataType::F16;
case Format::F32:
@@ -260,6 +265,7 @@ inline int plane_idx_from_channel(Format format, Channel channel)
case Format::S16:
case Format::U32:
case Format::S32:
+ case Format::BFLOAT16:
case Format::F16:
case Format::F32:
case Format::UV88:
@@ -447,6 +453,7 @@ inline size_t num_planes_from_format(Format format)
case Format::U16:
case Format::S32:
case Format::U32:
+ case Format::BFLOAT16:
case Format::F16:
case Format::F32:
case Format::RGB888:
@@ -481,6 +488,7 @@ inline size_t num_channels_from_format(Format format)
case Format::S16:
case Format::U32:
case Format::S32:
+ case Format::BFLOAT16:
case Format::F16:
case Format::F32:
return 1;
@@ -531,6 +539,7 @@ inline DataType get_promoted_data_type(DataType dt)
case DataType::QSYMM8_PER_CHANNEL:
case DataType::QSYMM16:
case DataType::QASYMM16:
+ case DataType::BFLOAT16:
case DataType::F16:
case DataType::U32:
case DataType::S32:
@@ -596,6 +605,12 @@ inline std::tuple<PixelValue, PixelValue> get_min_max(DataType dt)
max = PixelValue(std::numeric_limits<int32_t>::max());
break;
}
+ case DataType::BFLOAT16:
+ {
+ min = PixelValue(bfloat16::lowest());
+ max = PixelValue(bfloat16::max());
+ break;
+ }
case DataType::F16:
{
min = PixelValue(std::numeric_limits<half>::lowest());
@@ -1284,6 +1299,8 @@ bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = Quantization
const auto val_s32 = static_cast<int32_t>(val);
return ((val_s32 == val) && val_s32 >= std::numeric_limits<int32_t>::lowest() && val_s32 <= std::numeric_limits<int32_t>::max());
}
+ case DataType::BFLOAT16:
+ return (val >= bfloat16::lowest() && val <= bfloat16::max());
case DataType::F16:
return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max());
case DataType::F32:
@@ -1323,6 +1340,11 @@ void print_consecutive_elements_impl(std::ostream &s, const T *ptr, unsigned int
// We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
s << std::right << static_cast<T>(ptr[i]) << element_delim;
}
+ else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
+ {
+ // We use T instead of print_type here is because the std::is_floating_point<bfloat> returns false and then the print_type becomes int.
+ s << std::right << float(ptr[i]) << element_delim;
+ }
else
{
s << std::right << static_cast<print_type>(ptr[i]) << element_delim;
@@ -1357,6 +1379,11 @@ int max_consecutive_elements_display_width_impl(std::ostream &s, const T *ptr, u
// We use T instead of print_type here is because the std::is_floating_point<half> returns false and then the print_type becomes int.
ss << static_cast<T>(ptr[i]);
}
+ else if(std::is_same<typename std::decay<T>::type, bfloat16>::value)
+ {
+ // We use T instead of print_type here is because the std::is_floating_point<bfloat> returns false and then the print_type becomes int.
+ ss << float(ptr[i]);
+ }
else
{
ss << static_cast<print_type>(ptr[i]);