aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemm.cpp
diff options
context:
space:
mode:
authorJonathan Deakin <jonathan.deakin@arm.com>2023-01-12 11:41:14 +0000
committerJonathan Deakin <jonathan.deakin@arm.com>2023-02-01 08:05:35 +0000
commit464ed2087c2ce2d2e741cc1e1dc4bd49d06e7d26 (patch)
treeda07a18be246742773a729e264080d9a9b314d59 /src/cpu/operators/CpuGemm.cpp
parent7594f989963724e127c3e28210d60fed590b0524 (diff)
downloadComputeLibrary-464ed2087c2ce2d2e741cc1e1dc4bd49d06e7d26.tar.gz
Remove fixed format strides hack
- Remove hack in CpuGemmAssemblyDispatch.cpp which tried to guess strides for fixed format kernels. Instead, expect that strides will have been correctly set on weights externally - Update fixed format test fixtures to set the strides - If the fixed format uses fast math mode, then weights should be of type BFLOAT16. Change the validation logic to accept this. Resolves: [ONCPUML-1131] Co-authored-by: Milos Puzovic <Milos.Puzovic@arm.com> Change-Id: I0f18d8b86b0f639be25fd122fa06a591e90645f2 Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8985 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r--src/cpu/operators/CpuGemm.cpp14
1 files changed, 12 insertions, 2 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index a17e4f31d5..545d59f410 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -158,7 +158,17 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+
+ if (is_fixed_format_fast_math(gemm_info.weight_format()))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");