aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators')
-rw-r--r--src/cpu/operators/CpuMatMul.cpp28
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp31
2 files changed, 49 insertions, 10 deletions
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp
index 89087129c3..f68ae9883f 100644
--- a/src/cpu/operators/CpuMatMul.cpp
+++ b/src/cpu/operators/CpuMatMul.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -102,8 +102,8 @@ Status CpuMatMul::validate(const ITensorInfo *lhs,
const ActivationLayerInfo &act_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8,
- DataType::QASYMM8_SIGNED);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::BFLOAT16,
+ DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic.");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic.");
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
@@ -120,6 +120,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs,
auto gemm_info = AsmGemmInfo();
gemm_info.activation_info = act_info;
gemm_info.fast_mode = settings.fast_math();
+ gemm_info.fixed_format = settings.fixed_format();
// Validate and then permute a/b
if (adj_lhs)
@@ -157,6 +158,14 @@ Status CpuMatMul::validate(const ITensorInfo *lhs,
gemm_info.activation_info, gemm_info.output_stage));
}
+ if (gemm_info.fixed_format)
+ {
+ gemm_info.weight_format = WeightFormat::ANY;
+ arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY;
+ ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, lhs_to_use,
+ rhs_to_use, nullptr, dst, gemm_info));
+ }
+
cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info);
return Status{};
@@ -221,6 +230,7 @@ void CpuMatMul::configure(ITensorInfo *lhs,
// Fill AsmGemmInfo class object before configuration
_gemm_info.activation_info = act_info;
_gemm_info.fast_mode = settings.fast_math();
+ _gemm_info.fixed_format = settings.fixed_format();
_gemm_info.negated_offsets = false;
lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use;
@@ -233,6 +243,18 @@ void CpuMatMul::configure(ITensorInfo *lhs,
_gemm_info.output_stage);
}
+ if (_gemm_info.fixed_format)
+ {
+ _gemm_info.weight_format = WeightFormat::ANY;
+ arm_compute::WeightFormat expected_weight_format = WeightFormat::ANY;
+ ARM_COMPUTE_ERROR_THROW_ON(cpu::CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, &lhs_to_use,
+ &rhs_to_use, nullptr, dst, _gemm_info));
+ // Set gemm weights info to the one returned by has_opt_impl
+ _gemm_info.weight_format = expected_weight_format;
+ // has_opt_impl may return a non fast math kernel, even if we requested one
+ _gemm_info.fast_mode = arm_compute::is_fixed_format_fast_math(expected_weight_format);
+ }
+
// Configure Asm Kernel
_asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
_asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use,
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 58ee68fd49..efe2a7a67e 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -581,7 +581,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
-
const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
b_to_use->info()->offset_first_element_in_bytes());
@@ -857,6 +856,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, &cfg);
+
// TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
@@ -900,9 +900,18 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
#if defined(ARM_COMPUTE_ENABLE_BF16)
case DataType::BFLOAT16:
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ if (d->data_type() == DataType::BFLOAT16)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ }
break;
}
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
@@ -958,8 +967,9 @@ Status CpuGemmAssemblyDispatch::validate(
"Only F32 output supported for F32 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16,
"Only F16 output supported for F16 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32,
- "Only F32 output supported for BFLOAT16 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 &&
+ (d->data_type() != DataType::F32 && d->data_type() != DataType::BFLOAT16),
+ "Only F32/BFLOAT16 output supported for BFLOAT16 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32,
"Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
@@ -1030,7 +1040,14 @@ void CpuGemmAssemblyDispatch::configure(
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
case DataType::BFLOAT16:
- create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ if (d->data_type() == DataType::BFLOAT16)
+ {
+ create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+ }
+ else
+ {
+ create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ }
break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC