diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-02-26 09:58:13 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-03-05 15:15:15 +0000 |
commit | e8291acc1d9e89c9274d31f0d5bb4779eb95588c (patch) | |
tree | 5a0fef36d6daabe387174e55b60de54557c75291 /src/core | |
parent | aa85cdf22802cb892d7fa422ca505a43d84adb38 (diff) | |
download | ComputeLibrary-e8291acc1d9e89c9274d31f0d5bb4779eb95588c.tar.gz |
COMPMID-3152: Initial Bfloat16 support
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp | 93 | ||||
-rw-r--r-- | src/core/Utils.cpp | 13 |
2 files changed, 96 insertions, 10 deletions
diff --git a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp index f824f7ac58..79dc2cb585 100644 --- a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp +++ b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp @@ -33,7 +33,7 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/SaturateCast.h" -#include <arm_neon.h> +#include "arm_compute/core/NEON/wrapper/wrapper.h" using namespace arm_compute; @@ -43,11 +43,16 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, C { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(output); + ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(input); + ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(output); ARM_COMPUTE_UNUSED(policy); ARM_COMPUTE_RETURN_ERROR_ON(input == output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::F16, DataType::F32, DataType::S32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16, - DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, + DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16, + DataType::F32, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, + DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16, + DataType::U32, DataType::S32, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON(shift >= 8); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QASYMM8_SIGNED && (output->data_type() != DataType::S16 && output->data_type() != DataType::S32 @@ -68,15 +73,18 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, C ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::U8 && output->data_type() != DataType::S32), "Only data_types supported [in] S16 -> [out] U8, S32"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::BFLOAT16 && output->data_type() != DataType::F32, + "Only data_types supported [in] BFLOAT16 -> [out] F32"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8 && output->data_type() != DataType::U8 && output->data_type() != DataType::F32 && output->data_type() != DataType::S32), "Only data_types supported [in] F16 -> [out] QASYMM8, F32, S32, U8"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8 - && output->data_type() != DataType::F16 + && output->data_type() != DataType::F16 && output->data_type() != DataType::BFLOAT16 && output->data_type() != DataType::S32 && output->data_type() != DataType::U8), - "Only data_types supported [in] F32 -> [out] QASYMM8, F16, S32, U8"); + "Only data_types supported [in] F32 -> [out] QASYMM8, BFLOAT16, F16, S32, U8"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8 && output->data_type() != DataType::F16 @@ -786,6 +794,52 @@ void NEDepthConvertLayerKernel::run(const Window &window, const ThreadInfo &info } break; } +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) + case DataType::BFLOAT16: + switch(_output->info()->data_type()) + { + case DataType::F32: + { + /* Up-conversion BFLOAT16 -> F32 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<const bfloat16 *>(input.ptr()); + const auto output_ptr = reinterpret_cast<float *>(output.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const uint16x8x2_t texels = + { + { + vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())), + vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8) + } + }; + + vst1q_f32(reinterpret_cast<float *>(output.ptr()), + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16))); + vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16))); + vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16))); + vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12, + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16))); + } + + for(; x < window_end_x; ++x) + { + *(output_ptr + x) = float(*(input_ptr + x)); + } + }, + input, output); + break; + } + default: + ARM_COMPUTE_ERROR("Output data type unsupported"); + } + break; +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: switch(_output->info()->data_type()) @@ -980,6 +1034,33 @@ void NEDepthConvertLayerKernel::run(const Window &window, const ThreadInfo &info break; } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) + case DataType::BFLOAT16: + { + /* Down-conversion F32 -> BFLOAT16 */ + execute_window_loop(win, [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<const float *>(input.ptr()); + const auto output_ptr = reinterpret_cast<bfloat16 *>(output.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()), + reinterpret_cast<uint16_t *>(output.ptr())); + wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()) + 8, + reinterpret_cast<uint16_t *>(output.ptr()) + 8); + } + + for(; x < window_end_x; ++x) + { + *(output_ptr + x) = *(input_ptr + x); + } + }, + input, output); + break; + } +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ case DataType::S32: { const float scale_s = 1.f / (1 << _shift); diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp index bca5a31914..fb86d78cb7 100644 --- a/src/core/Utils.cpp +++ b/src/core/Utils.cpp @@ -510,12 +510,15 @@ void arm_compute::print_consecutive_elements(std::ostream &s, DataType dt, const case DataType::S32: print_consecutive_elements_impl<int32_t>(s, reinterpret_cast<const int32_t *>(ptr), n, stream_width, element_delim); break; - case DataType::F32: - print_consecutive_elements_impl<float>(s, reinterpret_cast<const float *>(ptr), n, stream_width, element_delim); + case DataType::BFLOAT16: + print_consecutive_elements_impl<bfloat16>(s, reinterpret_cast<const bfloat16 *>(ptr), n, stream_width, element_delim); break; case DataType::F16: print_consecutive_elements_impl<half>(s, reinterpret_cast<const half *>(ptr), n, stream_width, element_delim); break; + case DataType::F32: + print_consecutive_elements_impl<float>(s, reinterpret_cast<const float *>(ptr), n, stream_width, element_delim); + break; default: ARM_COMPUTE_ERROR("Undefined element size for given data type"); } @@ -542,10 +545,12 @@ int arm_compute::max_consecutive_elements_display_width(std::ostream &s, DataTyp return max_consecutive_elements_display_width_impl<uint32_t>(s, reinterpret_cast<const uint32_t *>(ptr), n); case DataType::S32: return max_consecutive_elements_display_width_impl<int32_t>(s, reinterpret_cast<const int32_t *>(ptr), n); - case DataType::F32: - return max_consecutive_elements_display_width_impl<float>(s, reinterpret_cast<const float *>(ptr), n); + case DataType::BFLOAT16: + return max_consecutive_elements_display_width_impl<bfloat16>(s, reinterpret_cast<const bfloat16 *>(ptr), n); case DataType::F16: return max_consecutive_elements_display_width_impl<half>(s, reinterpret_cast<const half *>(ptr), n); + case DataType::F32: + return max_consecutive_elements_display_width_impl<float>(s, reinterpret_cast<const float *>(ptr), n); default: ARM_COMPUTE_ERROR("Undefined element size for given data type"); } |