From 0c17aa25a4f7bc812707150b91930f0cf8e75294 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 27 Sep 2019 09:23:15 +0100 Subject: COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16 Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1991 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 162 ++++++++++++++++++--- tests/validation/fixtures/GEMMFixture.h | 24 ++- tests/validation/reference/GEMM.cpp | 55 ++++++- tests/validation/reference/GEMM.h | 5 +- 4 files changed, 218 insertions(+), 28 deletions(-) (limited to 'tests') diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 99f5ffe191..b885bfe4af 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -60,10 +60,20 @@ using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture; +// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision +template +using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture = + GEMMMatrixMultiplyReshapedValidationFixture; + // Fixture for CLGEMMMatrixMultiplyReshaped3D template using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture; +// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision +template +using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture = + GEMMMatrixMultiplyReshaped3DValidationFixture; + namespace { // *INDENT-OFF* @@ -71,15 +81,12 @@ namespace RelativeTolerance rel_tolerance_f32(0.001f); constexpr float abs_tolerance_f32(0.0001f); +RelativeTolerance rel_tolerance_f16_mixed_precision(0.001f); +constexpr float abs_tolerance_f16_mixed_precision(0.01f); + RelativeTolerance rel_tolerance_f16(0.001f); constexpr 
float abs_tolerance_f16(0.01f); -/** Alpha values to test - Precommit */ -const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); - -/** Beta values to test - Precommit */ -const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} ); - /** M values to test */ const auto m_values = framework::dataset::make("M", 37); @@ -105,6 +112,12 @@ const auto act_values = framework::dataset::make("Activation", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f), }); +/** Alpha values to test - Precommit */ +const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} ); + /** M0 values to test - Precommit */ const auto m0_values_precommit = framework::dataset::make("M0", { 4 }); @@ -120,6 +133,12 @@ const auto v0_values_precommit = framework::dataset::make("V0", 1, 3); /** H0 values to test - Precommit */ const auto h0_values_precommit = framework::dataset::make("H0", 1, 3); +/** Alpha values to test - Nightly */ +const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} ); + +/** Beta values to test - Nightly */ +const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} ); + /** M0 values to test - Nightly */ const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 }); @@ -167,8 +186,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -191,8 +210,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), 
broadcast_bias_values), lhs_transpose_values), act_values)) @@ -216,8 +235,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -240,8 +259,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -266,8 +285,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -290,8 +309,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -315,8 +334,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -339,8 +358,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -348,6 +367,105 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, validate(CLAccessor(_target), _reference, 
rel_tolerance_f16, 0.f, abs_tolerance_f16); } TEST_SUITE_END() // FP16 + +TEST_SUITE(MixedPrecision) + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + 
k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} +TEST_SUITE_END() // MixedPrecision TEST_SUITE_END() // Float TEST_SUITE_END() // GEMMMatrixMultiplyReshaped TEST_SUITE_END() // CL diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index 854cc4a22b..bf919c9b09 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -667,7 +667,7 @@ protected: SimpleTensor _reference{}; }; -template +template class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture { public: @@ -734,6 +734,7 @@ protected: kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; kernel_info.activation_info = act_info; + kernel_info.fp_mixed_precision = fp_mixed_precision; // The output tensor will be auto-initialized within the function @@ -807,14 +808,21 @@ protected: } } - 
return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + if(fp_mixed_precision) + { + return reference::activation_layer(reference::gemm_mixed_precision(lhs, rhs, bias, alpha, beta), act_info); + } + else + { + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } } TensorType _target{}; SimpleTensor _reference{}; }; -template +template class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture { public: @@ -879,6 +887,7 @@ protected: kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = true; kernel_info.activation_info = act_info; + kernel_info.fp_mixed_precision = fp_mixed_precision; // The output tensor will be auto-initialized within the function @@ -951,7 +960,14 @@ protected: memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); } - return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + if(fp_mixed_precision) + { + return reference::activation_layer(reference::gemm_mixed_precision(lhs, rhs, bias, alpha, beta), act_info); + } + else + { + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } } TensorType _target{}; diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp index 2feab89950..3c72b94143 100644 --- a/tests/validation/reference/GEMM.cpp +++ b/tests/validation/reference/GEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. 
* * SPDX-License-Identifier: MIT * * @@ -84,8 +84,61 @@ SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const S return dst; } +template ::value, int>::type> +SimpleTensor gemm_mixed_precision(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta) +{ + // GEMM mixed-precision combines F32 accumulators with F16 multiplications + // Create reference + SimpleTensor dst{ c.shape(), c.data_type(), 1 }; + + // Compute reference + const int M = a.shape().y(); + const int N = b.shape().x(); + const int K = a.shape().x(); + const int D = a.shape().z(); // Number of matrices in a batch + const int W = a.shape()[3]; // Number of batched-gemm (Winograd case) + + const int a_stride_z = K * M; + const int a_stride_w = K * M * D; + + const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3rd dimension in case matrix B has less than 3 dimensions + const int b_stride_w = b.shape().num_dimensions() > 3 ? 
K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions + + const int c_stride_z = N * M; + const int c_stride_w = N * M * D; + + for(int w = 0; w < W; ++w) + { + for(int depth = 0; depth < D; ++depth) + { + const int base_addr_a = depth * a_stride_z + w * a_stride_w; + const int base_addr_b = depth * b_stride_z + w * b_stride_w; + const int base_addr_c = depth * c_stride_z + w * c_stride_w; + + for(int row = 0; row < M; ++row) + { + for(int col = 0; col < N; ++col) + { + float acc(0); + + for(int k = 0; k < K; ++k) + { + acc += static_cast(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]); + } + + // Finalize the result: alpha * A * B + beta * C + dst[base_addr_c + col + row * N] = static_cast(alpha * acc + beta * c[base_addr_c + col + row * N]); + } + } + } + } + + return dst; +} + template SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); template SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); +template SimpleTensor gemm_mixed_precision(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h index 39007c60bc..9bcd640770 100644 --- a/tests/validation/reference/GEMM.h +++ b/tests/validation/reference/GEMM.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -38,6 +38,9 @@ namespace reference template ::value, int>::type = 0> SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); +template ::value, int>::type = 0> +SimpleTensor gemm_mixed_precision(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); + } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1