aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2020-02-26 09:58:13 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-03-05 15:15:15 +0000
commite8291acc1d9e89c9274d31f0d5bb4779eb95588c (patch)
tree5a0fef36d6daabe387174e55b60de54557c75291 /src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
parentaa85cdf22802cb892d7fa422ca505a43d84adb38 (diff)
downloadComputeLibrary-e8291acc1d9e89c9274d31f0d5bb4779eb95588c.tar.gz
COMPMID-3152: Initial Bfloat16 support
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: Ie6959e37e13731c86b2ee29392a99a293450a1b4 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2824 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp93
1 files changed, 87 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
index f824f7ac58..79dc2cb585 100644
--- a/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthConvertLayerKernel.cpp
@@ -33,7 +33,7 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/SaturateCast.h"
-#include <arm_neon.h>
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
using namespace arm_compute;
@@ -43,11 +43,16 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, C
{
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(output);
ARM_COMPUTE_UNUSED(policy);
ARM_COMPUTE_RETURN_ERROR_ON(input == output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::F16, DataType::F32, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16,
- DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+ DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::F32, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::U8,
+ DataType::S16, DataType::U16, DataType::BFLOAT16, DataType::F16,
+ DataType::U32, DataType::S32, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(shift >= 8);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QASYMM8_SIGNED && (output->data_type() != DataType::S16 && output->data_type() != DataType::S32
@@ -68,15 +73,18 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, C
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::U8 && output->data_type() != DataType::S32),
"Only data_types supported [in] S16 -> [out] U8, S32");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::BFLOAT16 && output->data_type() != DataType::F32,
+ "Only data_types supported [in] BFLOAT16 -> [out] F32");
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F16 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
&& output->data_type() != DataType::U8
&& output->data_type() != DataType::F32 && output->data_type() != DataType::S32),
"Only data_types supported [in] F16 -> [out] QASYMM8, F32, S32, U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::F32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
- && output->data_type() != DataType::F16
+ && output->data_type() != DataType::F16 && output->data_type() != DataType::BFLOAT16
&& output->data_type() != DataType::S32 && output->data_type() != DataType::U8),
- "Only data_types supported [in] F32 -> [out] QASYMM8, F16, S32, U8");
+ "Only data_types supported [in] F32 -> [out] QASYMM8, BFLOAT16, F16, S32, U8");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && (output->data_type() != DataType::QASYMM8_SIGNED && output->data_type() != DataType::QASYMM8
&& output->data_type() != DataType::F16
@@ -786,6 +794,52 @@ void NEDepthConvertLayerKernel::run(const Window &window, const ThreadInfo &info
}
break;
}
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+ case DataType::BFLOAT16:
+ switch(_output->info()->data_type())
+ {
+ case DataType::F32:
+ {
+ /* Up-conversion BFLOAT16 -> F32 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<const bfloat16 *>(input.ptr());
+ const auto output_ptr = reinterpret_cast<float *>(output.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint16x8x2_t texels =
+ {
+ {
+ vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr())),
+ vld1q_u16(reinterpret_cast<uint16_t *>(input.ptr()) + 8)
+ }
+ };
+
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()),
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[0])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 4,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[0])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 8,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_low_u16(texels.val[1])), 16)));
+ vst1q_f32(reinterpret_cast<float *>(output.ptr()) + 12,
+ vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vget_high_u16(texels.val[1])), 16)));
+ }
+
+ for(; x < window_end_x; ++x)
+ {
+ *(output_ptr + x) = float(*(input_ptr + x));
+ }
+ },
+ input, output);
+ break;
+ }
+ default:
+ ARM_COMPUTE_ERROR("Output data type unsupported");
+ }
+ break;
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
switch(_output->info()->data_type())
@@ -980,6 +1034,33 @@ void NEDepthConvertLayerKernel::run(const Window &window, const ThreadInfo &info
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+ case DataType::BFLOAT16:
+ {
+ /* Down-conversion F32 -> BFLOAT16 */
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<const float *>(input.ptr());
+ const auto output_ptr = reinterpret_cast<bfloat16 *>(output.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()),
+ reinterpret_cast<uint16_t *>(output.ptr()));
+ wrapper::vcvt_bf16_f32(reinterpret_cast<float *>(input.ptr()) + 8,
+ reinterpret_cast<uint16_t *>(output.ptr()) + 8);
+ }
+
+ for(; x < window_end_x; ++x)
+ {
+ *(output_ptr + x) = *(input_ptr + x);
+ }
+ },
+ input, output);
+ break;
+ }
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
case DataType::S32:
{
const float scale_s = 1.f / (1 << _shift);