diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp index cbbc5922dd..83051d2efe 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp @@ -103,7 +103,7 @@ const auto k0_values_precommit = framework::dataset::make("K0", { 4 }); const auto h0_values_precommit = framework::dataset::make("H0", 1, 3); /** M0 values to test - Nightly */ -const auto m0_values_nightly = framework::dataset::make("M0", 2, 8); +const auto m0_values_nightly = framework::dataset::make("M0", 1, 8); /** N0 values to test - Nightly */ const auto n0_values_nightly = framework::dataset::make("N0", { 2, 3, 4, 8 }); @@ -118,10 +118,10 @@ const auto h0_values_nightly = framework::dataset::make("H0", 1, 4); const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false }); /** Transpose values to test with RHS matrix */ -const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true }); +const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false }); /** Configuration test */ -void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value, bool i_value_rhs, DataType data_type) +void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value, bool i_value_rhs, bool t_value_rhs, DataType data_type) { const unsigned int M = m_value; const unsigned int N = n_value; @@ -136,7 +136,7 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned rhs_info.k0 = k0_value; rhs_info.h0 = h0_value; rhs_info.interleave = i_value_rhs; - rhs_info.transpose = true; + rhs_info.transpose = t_value_rhs; GEMMReshapeInfo gemm_info(M, N, K); @@ -168,7 +168,7 @@ TEST_SUITE(CL) TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRHS) TEST_SUITE(Float) TEST_SUITE(FP32) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine( +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -178,9 +178,10 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi k0_values_precommit), h0_values_precommit), i_values_rhs), -m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs) + t_values_rhs), +m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs) { - validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, DataType::F32); + validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs, DataType::F32); } FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::ALL, |