diff options
author | Renato Arantes <renato.arantes@arm.com> | 2024-01-26 17:31:18 +0000 |
---|---|---|
committer | Renato Barros Arantes <renato.arantes@arm.com> | 2024-03-21 11:15:30 +0000 |
commit | 36a75dafdbe6d6a3a6f50bd075fe01f5b7dace38 (patch) | |
tree | 0701d615ef30444b9d0789db691b59b81fd9e86e /src/cpu/operators/CpuMatMul.cpp | |
parent | d2191150736dde66d79eb97e0c8ee506eef3c8fc (diff) | |
download | ComputeLibrary-36a75dafdbe6d6a3a6f50bd075fe01f5b7dace38.tar.gz |
[ONCPUML-1451] Add matmul kernel to enable bf16 to bf16 operations via PyTorch® autocast() function
The full range of tests must be added with [MLINFSW-482] epic due to the lack of reordering kernels implemented in Acl.
Co-Authored-By: David Mansell <David.Mansell@arm.com>
Change-Id: I820d316295a1ec94fdc89c37e4144a268f914c36
Signed-off-by: Renato Arantes <renato.arantes@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11169
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuMatMul.cpp')
-rw-r--r-- | src/cpu/operators/CpuMatMul.cpp | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp index 89087129c3..f68ae9883f 100644 --- a/src/cpu/operators/CpuMatMul.cpp +++ b/src/cpu/operators/CpuMatMul.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -102,8 +102,8 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8, - DataType::QASYMM8_SIGNED); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::BFLOAT16, + DataType::QASYMM8, DataType::QASYMM8_SIGNED); ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic."); ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic."); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs); @@ -120,6 +120,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, auto gemm_info = AsmGemmInfo(); gemm_info.activation_info = act_info; gemm_info.fast_mode = settings.fast_math(); + gemm_info.fixed_format = settings.fixed_format(); // Validate and then permute a/b if (adj_lhs) @@ -157,6 +158,14 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, gemm_info.activation_info, gemm_info.output_stage)); } + if (gemm_info.fixed_format) + { + gemm_info.weight_format = WeightFormat::ANY; + arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY; + ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, lhs_to_use, + rhs_to_use, nullptr, dst, gemm_info)); + } + cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info); return Status{}; @@ -221,6 +230,7 @@ void CpuMatMul::configure(ITensorInfo *lhs, // Fill AsmGemmInfo class object before configuration _gemm_info.activation_info = act_info; _gemm_info.fast_mode = settings.fast_math(); + _gemm_info.fixed_format = settings.fixed_format(); _gemm_info.negated_offsets = false; lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use; @@ -233,6 +243,18 @@ void CpuMatMul::configure(ITensorInfo *lhs, _gemm_info.output_stage); } + if (_gemm_info.fixed_format) + { + _gemm_info.weight_format = WeightFormat::ANY; + arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY; + ARM_COMPUTE_ERROR_THROW_ON(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, &lhs_to_use, + &rhs_to_use, nullptr, dst, _gemm_info)); + // Set gemm weights info to the one returned by has_opt_impl + _gemm_info.weight_format = expected_weight_format; + // has_opt_impl may return a non fast math kernel, even if we requested one + _gemm_info.fast_mode = arm_compute::is_fixed_format_fast_math(expected_weight_format); + } + // Configure Asm Kernel _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>(); _asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use, |