diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp | 11 | ||||
-rw-r--r-- | tests/validation/fixtures/GEMMLowpFixture.h | 18 |
2 files changed, 24 insertions, 5 deletions
diff --git a/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp index 30c91b7091..f4083bfd95 100644 --- a/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp +++ b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -130,7 +130,12 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned rhs_info.interleave = i_value_rhs; rhs_info.transpose = true; - GEMMReshapeInfo gemm_info(M, N, K); + GEMMKernelInfo gemm_info; + gemm_info.m = M; + gemm_info.n = N; + gemm_info.k = K; + gemm_info.lhs_info = lhs_info; + gemm_info.rhs_info = rhs_info; const TensorShape lhs_shape(K, M, b_value); const TensorShape rhs_shape(N, K, b_value); @@ -152,7 +157,7 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned // Create and configure function CLGEMMLowpMatrixMultiplyReshapedOnlyRHS gemm; - gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, gemm_info); + gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info); } } // namespace diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index d13ada9d64..0207f4c5ae 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -24,6 +24,7 @@ #ifndef ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE #define ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE +#include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" @@ -1012,13 +1013,19 @@ protected: const unsigned int N = rhs_shape[0]; const unsigned int K = lhs_shape[0]; + GEMMKernelInfo gemm_info; + gemm_info.m = M; + gemm_info.n = N; + gemm_info.k = K; + gemm_info.lhs_info = lhs_info; + gemm_info.rhs_info = rhs_info; // The output tensor will be auto-initialized within the function // Create and configure function ReshapeRHSFunctionType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K)); + gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -1148,13 +1155,20 @@ protected: const unsigned int N = rhs_shape[0]; const unsigned int K = lhs_shape[0]; + GEMMKernelInfo gemm_info; + gemm_info.m = M; + gemm_info.n = N; + gemm_info.k = K; + gemm_info.depth_output_gemm3d = m_h; + gemm_info.lhs_info = lhs_info; + gemm_info.rhs_info = rhs_info; // The output tensor will be auto-initialized within the function // Create and configure function ReshapeRHSFunctionType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h)); + gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); |