diff options
author | Renato Arantes <renato.arantes@arm.com> | 2024-01-26 17:31:18 +0000 |
---|---|---|
committer | Renato Barros Arantes <renato.arantes@arm.com> | 2024-03-21 11:15:30 +0000 |
commit | 36a75dafdbe6d6a3a6f50bd075fe01f5b7dace38 (patch) | |
tree | 0701d615ef30444b9d0789db691b59b81fd9e86e /src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | |
parent | d2191150736dde66d79eb97e0c8ee506eef3c8fc (diff) | |
download | ComputeLibrary-36a75dafdbe6d6a3a6f50bd075fe01f5b7dace38.tar.gz |
[ONCPUML-1451] Add matmul kernel to enable bf16 to bf16 operations via PyTorch® autocast() function
The full range of tests must be added under the [MLINFSW-482] epic, because the necessary reordering kernels are not implemented in ACL.
Co-Authored-By: David Mansell <David.Mansell@arm.com>
Change-Id: I820d316295a1ec94fdc89c37e4144a268f914c36
Signed-off-by: Renato Arantes <renato.arantes@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11169
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 31 |
1 files changed, 24 insertions, 7 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 58ee68fd49..efe2a7a67e 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -581,7 +581,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) // Fixed format kernels need no pretranspose. ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format( assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format))); - const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size(); const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes()); @@ -857,6 +856,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg); + // TODO: Incorporate info.transpose_b COMPMID-6595 switch (a->data_type()) { @@ -900,9 +900,18 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected #if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: { - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), - "We could not find an optimized kernel for BFLOAT16 input and F32 output"); + if (d->data_type() == DataType::BFLOAT16) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + !(arm_gemm::has_opt_gemm<bfloat16, float, 
arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), + "We could not find an optimized kernel for BFLOAT16 input and F32 output"); + } break; } #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ @@ -958,8 +967,9 @@ Status CpuGemmAssemblyDispatch::validate( "Only F32 output supported for F32 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, - "Only F32 output supported for BFLOAT16 input"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && + (d->data_type() != DataType::F32 && d->data_type() != DataType::BFLOAT16), + "Only F32/BFLOAT16 output supported for BFLOAT16 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, @@ -1030,7 +1040,14 @@ void CpuGemmAssemblyDispatch::configure( #endif /* __aarch64__ */ #if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: - create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); + if (d->data_type() == DataType::BFLOAT16) + { + create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info); + } + else + { + create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info); + } break; #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC |