diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 30 | ||||
-rw-r--r-- | tests/validation/fixtures/GEMMFixture.h | 12 |
2 files changed, 25 insertions, 17 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 997c510e42..8d13cdac57 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -103,7 +103,7 @@ const auto act_values = framework::dataset::make("Activation", }); /** M0 values to test - Precommit */ -const auto m0_values_precommit = framework::dataset::make("M0", {4, 6}); +const auto m0_values_precommit = framework::dataset::make("M0", {4, 8}); /** N0 values to test - Precommit */ const auto n0_values_precommit = framework::dataset::make("N0", { 4 }); @@ -141,8 +141,11 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal /** Broadcast bias from vector to matrix */ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } ); +/** LHS transposed values */ +const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } ); + /** Configuration test */ -void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, DataType data_type, const ActivationLayerInfo &act_info) +void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, bool lhs_transpose, DataType data_type, const ActivationLayerInfo &act_info) { const unsigned int M = m_value; const unsigned int N = n_value; @@ -153,14 +156,14 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned lhs_info.k0 = k0_value; lhs_info.v0 = v0_value; lhs_info.interleave = i_value_lhs; - lhs_info.transpose = false; + lhs_info.transpose = lhs_transpose; GEMMRHSMatrixInfo rhs_info; rhs_info.n0 = n0_value; rhs_info.k0 = k0_value; rhs_info.h0 = h0_value; rhs_info.interleave = i_value_rhs; - rhs_info.transpose = true; + rhs_info.transpose = !lhs_transpose; GEMMKernelInfo kernel_info; kernel_info.m = M; @@ -209,7 +212,7 @@ TEST_SUITE(CL) TEST_SUITE(GEMMMatrixMultiplyReshaped) TEST_SUITE(Float) TEST_SUITE(FP32) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -222,14 +225,15 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi i_values_lhs), i_values_rhs), broadcast_bias_values), + lhs_transpose_values), act_values), -m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, act_value) +m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, act_value) { - validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, DataType::F32, act_value); + validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, DataType::F32, act_value); } FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -245,6 +249,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra a_values), beta_values), broadcast_bias_values), + lhs_transpose_values), act_values)) { // Validate output @@ -252,7 +257,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra } FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -268,6 +273,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra a_values), beta_values), broadcast_bias_values), + lhs_transpose_values), act_values)) { // Validate output @@ -275,7 +281,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra } FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -291,6 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), + lhs_transpose_values), act_values)) { // Validate output @@ -298,7 +305,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, } FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -314,6 +321,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), + lhs_transpose_values), act_values)) { // Validate output diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index a04a901b1c..854cc4a22b 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -673,21 +673,21 @@ class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture public: template <typename...> void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool interleave_lhs, - bool interleave_rhs, DataType data_type, float alpha, float beta, bool broadcast_bias, const ActivationLayerInfo &act_info) + bool interleave_rhs, DataType data_type, float alpha, float beta, bool broadcast_bias, bool lhs_transpose, const ActivationLayerInfo &act_info) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0; lhs_info.k0 = k0; lhs_info.v0 = v0; lhs_info.interleave = interleave_lhs; - lhs_info.transpose = false; + lhs_info.transpose = lhs_transpose; GEMMRHSMatrixInfo rhs_info; rhs_info.n0 = n0; rhs_info.k0 = k0; rhs_info.h0 = h0; rhs_info.interleave = interleave_rhs; - rhs_info.transpose = true; + rhs_info.transpose = !lhs_transpose; // Set the tensor shapes for LHS and RHS matrices const TensorShape lhs_shape(k, m, batch_size); @@ -821,21 +821,21 @@ public: template <typename...> void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool interleave_lhs, - bool interleave_rhs, DataType data_type, float alpha, float beta, const ActivationLayerInfo &act_info) + bool interleave_rhs, DataType data_type, float alpha, float beta, bool lhs_transpose, const ActivationLayerInfo &act_info) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0; lhs_info.k0 = k0; lhs_info.v0 = v0; lhs_info.interleave = interleave_lhs; - lhs_info.transpose = false; + lhs_info.transpose = lhs_transpose; GEMMRHSMatrixInfo rhs_info; rhs_info.n0 = n0; rhs_info.k0 = k0; rhs_info.h0 = h0; rhs_info.interleave = interleave_rhs; - rhs_info.transpose = true; + rhs_info.transpose = !lhs_transpose; // In case of GEMM3D, m is the product between m_w and m_h const unsigned int m = m_w * m_h; |