From 0c17aa25a4f7bc812707150b91930f0cf8e75294 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 27 Sep 2019 09:23:15 +0100 Subject: COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16 Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1991 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 162 ++++++++++++++++++--- 1 file changed, 140 insertions(+), 22 deletions(-) (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp') diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 99f5ffe191..b885bfe4af 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -60,10 +60,20 @@ using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture; +// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision +template +using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture = + GEMMMatrixMultiplyReshapedValidationFixture; + // Fixture for CLGEMMMatrixMultiplyReshaped3D template using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture; +// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision +template +using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture = + GEMMMatrixMultiplyReshaped3DValidationFixture; + namespace { // *INDENT-OFF* @@ -71,15 +81,12 @@ namespace RelativeTolerance rel_tolerance_f32(0.001f); constexpr float abs_tolerance_f32(0.0001f); +RelativeTolerance rel_tolerance_f16_mixed_precision(0.001f); +constexpr float abs_tolerance_f16_mixed_precision(0.01f); + RelativeTolerance rel_tolerance_f16(0.001f); constexpr float abs_tolerance_f16(0.01f); -/** Alpha values to test - Precommit */ -const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); - -/** Beta values to test - Precommit */ -const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} ); - /** M values to test */ const auto m_values = framework::dataset::make("M", 37); @@ -105,6 +112,12 @@ const auto act_values = framework::dataset::make("Activation", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f), }); +/** Alpha values to test - Precommit */ +const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} ); + /** M0 values to test - Precommit */ const auto m0_values_precommit = framework::dataset::make("M0", { 4 }); @@ -120,6 +133,12 @@ const auto v0_values_precommit = framework::dataset::make("V0", 1, 3); /** H0 values to test - Precommit */ const auto h0_values_precommit = framework::dataset::make("H0", 1, 3); +/** Alpha values to test - Nightly */ +const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} ); + +/** Beta values to test - Nightly */ +const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} ); + /** M0 values to test - Nightly */ const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 }); @@ -167,8 +186,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -191,8 +210,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -216,8 +235,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -240,8 +259,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -266,8 +285,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -290,8 +309,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -315,8 +334,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -339,8 +358,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -348,6 +367,105 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); } TEST_SUITE_END() // FP16 + +TEST_SUITE(MixedPrecision) + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} +TEST_SUITE_END() // MixedPrecision TEST_SUITE_END() // Float TEST_SUITE_END() // GEMMMatrixMultiplyReshaped TEST_SUITE_END() // CL -- cgit v1.2.1