diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2019-08-01 14:22:12 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-09-02 15:31:29 +0000 |
commit | ae99b6eac40c7c3cb5eb465f3cbe4b522eff0488 (patch) | |
tree | 1cd14abfd10953686185d3697c545830f26ac7bb /tests/validation/CL | |
parent | c2a60593436387d20ff142a619f4c3955a5cd41b (diff) | |
download | ComputeLibrary-ae99b6eac40c7c3cb5eb465f3cbe4b522eff0488.tar.gz |
COMPMID-1965 Extend CLGEMMMatrixMultiplyReshapedKernel to support transposed LHS (t) and not-transpose RHS
Change-Id: I437a00d7213fefd6f4365071b46174d44df8b85c
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1677
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL')
-rw-r--r-- | tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 997c510e42..8d13cdac57 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -103,7 +103,7 @@ const auto act_values = framework::dataset::make("Activation", }); /** M0 values to test - Precommit */ -const auto m0_values_precommit = framework::dataset::make("M0", {4, 6}); +const auto m0_values_precommit = framework::dataset::make("M0", {4, 8}); /** N0 values to test - Precommit */ const auto n0_values_precommit = framework::dataset::make("N0", { 4 }); @@ -141,8 +141,11 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal /** Broadcast bias from vector to matrix */ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } ); +/** LHS transposed values */ +const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } ); + /** Configuration test */ -void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, DataType data_type, const ActivationLayerInfo &act_info) +void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, bool lhs_transpose, DataType data_type, const ActivationLayerInfo &act_info) { const unsigned int M = m_value; const unsigned int N = n_value; @@ -153,14 +156,14 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned lhs_info.k0 = k0_value; lhs_info.v0 = v0_value; lhs_info.interleave = i_value_lhs; - lhs_info.transpose = false; + lhs_info.transpose = lhs_transpose; GEMMRHSMatrixInfo rhs_info; rhs_info.n0 = n0_value; rhs_info.k0 = k0_value; rhs_info.h0 = h0_value; rhs_info.interleave = i_value_rhs; - rhs_info.transpose = true; + rhs_info.transpose = !lhs_transpose; GEMMKernelInfo kernel_info; kernel_info.m = M; @@ -209,7 +212,7 @@ TEST_SUITE(CL) TEST_SUITE(GEMMMatrixMultiplyReshaped) TEST_SUITE(Float) TEST_SUITE(FP32) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -222,14 +225,15 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi i_values_lhs), i_values_rhs), broadcast_bias_values), + lhs_transpose_values), act_values), -m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, act_value) +m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, act_value) { - validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, DataType::F32, act_value); + validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, DataType::F32, act_value); } FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -245,6 +249,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra a_values), beta_values), broadcast_bias_values), + lhs_transpose_values), act_values)) { // Validate output @@ -252,7 +257,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra } FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -268,6 +273,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra a_values), beta_values), broadcast_bias_values), + lhs_transpose_values), act_values)) { // Validate output @@ -275,7 +281,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra } FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -291,6 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), + lhs_transpose_values), act_values)) { // Validate output @@ -298,7 +305,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, } FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -314,6 +321,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), + lhs_transpose_values), act_values)) { // Validate output |