aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r--src/cpu/operators/CpuGemm.cpp14
1 files changed, 12 insertions, 2 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index a17e4f31d5..545d59f410 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -158,7 +158,17 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+
+ if (is_fixed_format_fast_math(gemm_info.weight_format()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");