aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/GEMM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/GEMM.cpp')
-rw-r--r--tests/validation/reference/GEMM.cpp30
1 files changed, 18 insertions, 12 deletions
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 20f1139a02..d513343796 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021,2024 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
namespace arm_compute
{
@@ -180,17 +181,22 @@ SimpleTensor<T> gemm_mixed_precision(
return dst;
}
-template SimpleTensor<float>
-gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
-template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a,
- const SimpleTensor<bfloat16> &b,
- const SimpleTensor<bfloat16> &c,
- float alpha,
- float beta);
-template SimpleTensor<half>
-gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
-template SimpleTensor<half> gemm_mixed_precision(
- const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
+void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst)
+{
+ // Compute reference
+ SimpleTensor<T> dst_gemm = gemm(a, b, c, alpha, beta);
+ reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE);
+}
+
+template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a, const SimpleTensor<bfloat16> &b, const SimpleTensor<bfloat16> &c, float alpha, float beta);
+template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
+template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+
+template void gemm_accumulate(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta, SimpleTensor<float> &dst);
+template void gemm_accumulate(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta, SimpleTensor<half> &dst);
+
+template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
} // namespace reference
} // namespace validation
} // namespace test