Diffstat (limited to 'tests/validation/fixtures/GEMMFixture.h')
-rw-r--r--  tests/validation/fixtures/GEMMFixture.h  60
1 file changed, 49 insertions(+), 11 deletions(-)
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index afde3d8067..94bedc83e1 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,14 +46,14 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
-class GEMMValidationFixture : public framework::Fixture
+class GEMMGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type, bool accumulate=false)
{
ARM_COMPUTE_UNUSED(pretranspose);
- _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type);
- _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type);
+ _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type, accumulate);
+ _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type, accumulate);
}
protected:
@@ -80,7 +80,7 @@ protected:
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
// Create tensors
TensorType a = create_tensor<TensorType>(shape_a, data_type, 1);
@@ -99,7 +99,7 @@ protected:
&dst,
alpha, beta,
GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, false, (reinterpret_input_as_3d
- || reinterpret_output_as_3d)));
+ || reinterpret_output_as_3d), arm_compute::ActivationLayerInfo(), false /* fixed_format */, arm_compute::WeightFormat::UNSPECIFIED, false /* pretranspose_B */, accumulate));
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
ARM_COMPUTE_ASSERT(c.info()->is_resizable());
@@ -121,11 +121,14 @@ protected:
// Fill tensors
fill(AccessorType(a), 0);
fill(AccessorType(b), 1);
+ if (accumulate)
+ {
+ fill(AccessorType(dst), 6);
+ }
if(!disable_c)
{
fill(AccessorType(c), 2);
}
-
// Run with variable inputs.
if(run_twice)
{
@@ -145,7 +148,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
TensorShape shape_a_to_use = shape_a;
if(reinterpret_input_as_3d)
@@ -158,6 +161,7 @@ protected:
SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
SimpleTensor<T> b{ shape_b, data_type, 1 };
SimpleTensor<T> c{ output_shape, data_type, 1 };
+ SimpleTensor<T> dst{ output_shape, data_type, 1 };
// Fill reference
fill(a, 0);
@@ -211,17 +215,51 @@ protected:
fill(c, 5);
}
+ // Do in place summation
+ if (accumulate)
+ {
+ fill(dst, 6);
+ }
+
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- auto r = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
- return r;
+ if (accumulate)
+ {
+ reference::gemm_accumulate<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta, dst);
+ return dst;
+ }
+ else
+ {
+ return reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
+ }
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, false /*accumulate*/);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMAccumulateValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ bool accumulate = true;
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, accumulate);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename T, typename GEMMOperatorType>
class GEMMMatrixMultiplyValidationFixture : public framework::Fixture
{
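For context only (this is not part of the patch above): the new accumulate path pre-fills the destination tensor and then calls reference::gemm_accumulate, which appears to sum alpha * A * B + beta * C on top of the existing destination contents. Below is a minimal standalone sketch of that semantic using plain row-major loops; the function name gemm_accumulate_ref, the M/N/K parameters, and the row-major layout are illustrative assumptions, not the library's SimpleTensor-based reference implementation.

```cpp
// Illustrative sketch: naive row-major GEMM that accumulates into dst,
// mirroring the behaviour exercised by GEMMAccumulateValidationFixture
// (dst is filled beforehand, then alpha*A*B + beta*C is added to it).
#include <cstddef>
#include <vector>

void gemm_accumulate_ref(const std::vector<float> &a,   // M x K
                         const std::vector<float> &b,   // K x N
                         const std::vector<float> &c,   // M x N (bias term)
                         std::vector<float>       &dst, // M x N, pre-filled
                         std::size_t M, std::size_t N, std::size_t K,
                         float alpha, float beta)
{
    for (std::size_t i = 0; i < M; ++i)
    {
        for (std::size_t j = 0; j < N; ++j)
        {
            float acc = 0.f;
            for (std::size_t k = 0; k < K; ++k)
            {
                acc += a[i * K + k] * b[k * N + j];
            }
            // Accumulate on top of the existing destination value instead of
            // overwriting it, which is what distinguishes the accumulate path
            // from the plain GEMM reference.
            dst[i * N + j] += alpha * acc + beta * c[i * N + j];
        }
    }
}
```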