aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorRenato Arantes <renato.arantes@arm.com>2024-01-26 17:31:18 +0000
committerRenato Barros Arantes <renato.arantes@arm.com>2024-03-21 11:15:30 +0000
commit36a75dafdbe6d6a3a6f50bd075fe01f5b7dace38 (patch)
tree0701d615ef30444b9d0789db691b59b81fd9e86e /src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
parentd2191150736dde66d79eb97e0c8ee506eef3c8fc (diff)
downloadComputeLibrary-36a75dafdbe6d6a3a6f50bd075fe01f5b7dace38.tar.gz
[ONCPUML-1451] Add matmul kernel to enable bf16 to bf16 operations via PyTorch® autocast() function
The full range of tests must be added with [MLINFSW-482] epic due to the lack of reordering kernels implemented in Acl. Co-Authored-By: David Mansell <David.Mansell@arm.com> Change-Id: I820d316295a1ec94fdc89c37e4144a268f914c36 Signed-off-by: Renato Arantes <renato.arantes@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11169 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp31
1 files changed, 24 insertions, 7 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 58ee68fd49..efe2a7a67e 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -581,7 +581,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
-
const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
b_to_use->info()->offset_first_element_in_bytes());
@@ -857,6 +856,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, &cfg);
+
// TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
@@ -900,9 +900,18 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
#if defined(ARM_COMPUTE_ENABLE_BF16)
case DataType::BFLOAT16:
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ if (d->data_type() == DataType::BFLOAT16)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ }
break;
}
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
@@ -958,8 +967,9 @@ Status CpuGemmAssemblyDispatch::validate(
"Only F32 output supported for F32 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16,
"Only F16 output supported for F16 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32,
- "Only F32 output supported for BFLOAT16 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 &&
+ (d->data_type() != DataType::F32 && d->data_type() != DataType::BFLOAT16),
+ "Only F32/BFLOAT16 output supported for BFLOAT16 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32,
"Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
@@ -1030,7 +1040,14 @@ void CpuGemmAssemblyDispatch::configure(
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
case DataType::BFLOAT16:
- create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ if (d->data_type() == DataType::BFLOAT16)
+ {
+ create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+ }
+ else
+ {
+ create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ }
break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC