diff options
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemm.cpp | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index a17e4f31d5..545d59f410 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 Arm Limited. + * Copyright (c) 2021-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -158,7 +158,17 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); + + if (is_fixed_format_fast_math(gemm_info.weight_format())) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); + } + ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported"); |