diff options
Diffstat (limited to 'tests/validation/reference/GEMM.cpp')
-rw-r--r-- | tests/validation/reference/GEMM.cpp | 30 |
1 files changed, 18 insertions, 12 deletions
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp index 20f1139a02..d513343796 100644 --- a/tests/validation/reference/GEMM.cpp +++ b/tests/validation/reference/GEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021,2024 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -25,6 +25,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Types.h" +#include "tests/validation/reference/ArithmeticOperations.h" namespace arm_compute { @@ -180,17 +181,22 @@ SimpleTensor<T> gemm_mixed_precision( return dst; } -template SimpleTensor<float> -gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta); -template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a, - const SimpleTensor<bfloat16> &b, - const SimpleTensor<bfloat16> &c, - float alpha, - float beta); -template SimpleTensor<half> -gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); -template SimpleTensor<half> gemm_mixed_precision( - const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); +template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> +void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst) +{ + // Compute reference + SimpleTensor<T> dst_gemm = gemm(a, b, c, alpha, beta); + reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE); +} + +template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a, const SimpleTensor<bfloat16> &b, const SimpleTensor<bfloat16> &c, float alpha, float beta); +template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta); +template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); + +template void gemm_accumulate(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta, SimpleTensor<float> &dst); +template void gemm_accumulate(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta, SimpleTensor<half> &dst); + +template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); } // namespace reference } // namespace validation } // namespace test |