From c8cc024603cb1db084227196a52e562bf251d339 Mon Sep 17 00:00:00 2001 From: Ramy Elgammal Date: Wed, 5 Oct 2022 17:05:20 +0100 Subject: Adding documentation section explaining how BF16 is used Resolves: COMPMID-5494 Signed-off-by: Ramy Elgammal Change-Id: I8f512745855b8ca21181a9ab21323bfff6aeb866 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/458884 Tested-by: bsgcomp Reviewed-by: Viet-Hoa Do Comments-Addressed: bsgcomp Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8391 Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez Tello Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- docs/user_guide/library.dox | 9 +++++++++ src/cpu/operators/CpuGemm.h | 2 +- src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 6 +++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/user_guide/library.dox b/docs/user_guide/library.dox index 7a45fe9d9d..b95e0bace3 100644 --- a/docs/user_guide/library.dox +++ b/docs/user_guide/library.dox @@ -54,6 +54,15 @@ When the fast-math flag is enabled, both Arm® Neon™ and CL convolution layers - no-fast-math: No Winograd support - fast-math: Supports Winograd 3x3,3x1,1x3,5x1,1x5,7x1,1x7,5x5,7x7 +@section BF16 acceleration + +- Required toolchain: android-ndk-r23-beta5 or later +- To build for BF16: "neon" flag should be set "=1" and "arch" has to be "=armv8.6-a", "=armv8.6-a-sve", or "=armv8.6-a-sve2" using following command: +- scons arch=armv8.6-a-sve neon=1 opencl=0 extra_cxx_flags="-fPIC" benchmark_tests=0 validation_tests=0 validation_examples=1 os=android Werror=0 toolchain_prefix=aarch64-linux-android29 +- To enable BF16 acceleration when running FP32 "fast-math" has to be enabled and that works only for Neon convolution layer using cpu gemm. + In this scenario on CPU: the CpuGemmConv2d kernel performs the conversion from FP32, type of input tensor, to BF16 at block level to exploit the arithmetic capabilities dedicated to BF16. Then transforms back to FP32, the output + tensor type. + @section architecture_thread_safety Thread-safety Although the library supports multi-threading during workload dispatch, thus parallelizing the execution of the workload at multiple threads, the current runtime module implementation is not thread-safe in the sense of executing different functions from separate threads. diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h index 8d34b22437..031f02b3fd 100644 --- a/src/cpu/operators/CpuGemm.h +++ b/src/cpu/operators/CpuGemm.h @@ -76,7 +76,7 @@ public: * |:------------|:-----------|:---------|:--------------| * |F32 |F32 |F32 |F32 | * |F16 |F16 |F16 |F16 | - * |BFLOAT16 |BFLOAT16 |BFLOAT16 |BFLOAT16 | + * |BFLOAT16 |BFLOAT16 |BFLOAT16 |FP32 | * * @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C]. * @note GEMM: The tensors a, b, c, d must have the same data type. You should not mix data types when calling this function. diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 77da83070b..ab668681ad 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -716,7 +716,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected if(d->data_type() == DataType::S32) { ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), - "We could not find an optimized kernel for U8/QASYMM8 input and S32 output"); + "We could not find an optimized kernel for U8/QASYMM8 input and U32 output"); } else { @@ -734,7 +734,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected else { ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), - "We could not find an optimized kernel for S8 input and S32 output"); + "We could not find an optimized kernel for S8 input and S8 output"); } break; #endif /* __aarch64__ */ @@ -749,7 +749,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), - "We could not find an optimized kernel for BFLOAT16 input and F32 output"); + "We could not find an optimized kernel for F16 input and F16 output"); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ default: -- cgit v1.2.1