aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h18
1 files changed, 16 insertions, 2 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index d13ada9d64..0207f4c5ae 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -24,6 +24,7 @@
#ifndef ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE
#define ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
@@ -1012,13 +1013,19 @@ protected:
const unsigned int N = rhs_shape[0];
const unsigned int K = lhs_shape[0];
+ GEMMKernelInfo gemm_info;
+ gemm_info.m = M;
+ gemm_info.n = N;
+ gemm_info.k = K;
+ gemm_info.lhs_info = lhs_info;
+ gemm_info.rhs_info = rhs_info;
// The output tensor will be auto-initialized within the function
// Create and configure function
ReshapeRHSFunctionType reshape_rhs;
GEMMFunctionType gemm;
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
- gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
+ gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info);
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -1148,13 +1155,20 @@ protected:
const unsigned int N = rhs_shape[0];
const unsigned int K = lhs_shape[0];
+ GEMMKernelInfo gemm_info;
+ gemm_info.m = M;
+ gemm_info.n = N;
+ gemm_info.k = K;
+ gemm_info.depth_output_gemm3d = m_h;
+ gemm_info.lhs_info = lhs_info;
+ gemm_info.rhs_info = rhs_info;
// The output tensor will be auto-initialized within the function
// Create and configure function
ReshapeRHSFunctionType reshape_rhs;
GEMMFunctionType gemm;
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
- gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));
+ gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info);
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);