From ae99b6eac40c7c3cb5eb465f3cbe4b522eff0488 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Thu, 1 Aug 2019 14:22:12 +0100 Subject: COMPMID-1965 Extend CLGEMMMatrixMultiplyReshapedKernel to support transposed LHS (t) and not-transpose RHS Change-Id: I437a00d7213fefd6f4365071b46174d44df8b85c Signed-off-by: Giorgio Arena Reviewed-on: https://review.mlplatform.org/c/1677 Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins --- tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 30 ++++++++++++++-------- 1 file changed, 19 insertions(+), 11 deletions(-) (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp') diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 997c510e42..8d13cdac57 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -103,7 +103,7 @@ const auto act_values = framework::dataset::make("Activation", }); /** M0 values to test - Precommit */ -const auto m0_values_precommit = framework::dataset::make("M0", {4, 6}); +const auto m0_values_precommit = framework::dataset::make("M0", {4, 8}); /** N0 values to test - Precommit */ const auto n0_values_precommit = framework::dataset::make("N0", { 4 }); @@ -141,8 +141,11 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal /** Broadcast bias from vector to matrix */ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } ); +/** LHS transposed values */ +const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } ); + /** Configuration test */ -void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, DataType data_type, const ActivationLayerInfo &act_info) +void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, bool lhs_transpose, DataType data_type, const ActivationLayerInfo &act_info) { const unsigned int M = m_value; const unsigned int N = n_value; @@ -153,14 +156,14 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned lhs_info.k0 = k0_value; lhs_info.v0 = v0_value; lhs_info.interleave = i_value_lhs; - lhs_info.transpose = false; + lhs_info.transpose = lhs_transpose; GEMMRHSMatrixInfo rhs_info; rhs_info.n0 = n0_value; rhs_info.k0 = k0_value; rhs_info.h0 = h0_value; rhs_info.interleave = i_value_rhs; - rhs_info.transpose = true; + rhs_info.transpose = !lhs_transpose; GEMMKernelInfo kernel_info; kernel_info.m = M; @@ -209,7 +212,7 @@ TEST_SUITE(CL) TEST_SUITE(GEMMMatrixMultiplyReshaped) TEST_SUITE(Float) TEST_SUITE(FP32) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -222,14 +225,15 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi i_values_lhs), i_values_rhs), broadcast_bias_values), + lhs_transpose_values), act_values), -m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, act_value) +m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, act_value) { - validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, DataType::F32, act_value); + validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, DataType::F32, act_value); } FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -245,6 +249,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fra a_values), beta_values), broadcast_bias_values), + lhs_transpose_values), act_values)) { // Validate output @@ -252,7 +257,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fra } FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -268,6 +273,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fra a_values), beta_values), broadcast_bias_values), + lhs_transpose_values), act_values)) { // Validate output @@ -275,7 +281,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fra } FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -291,6 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), + lhs_transpose_values), act_values)) { // Validate output @@ -298,7 +305,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, } FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -314,6 +321,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), + lhs_transpose_values), act_values)) { // Validate output -- cgit v1.2.1