diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 162 | ||||
-rw-r--r-- | tests/validation/fixtures/GEMMFixture.h | 24 | ||||
-rw-r--r-- | tests/validation/reference/GEMM.cpp | 55 | ||||
-rw-r--r-- | tests/validation/reference/GEMM.h | 5 |
4 files changed, 218 insertions, 28 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 99f5ffe191..b885bfe4af 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -60,10 +60,20 @@ using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMMatrixMultiplyRe template <typename T> using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>; +// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision +template <typename T> +using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture = + GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>; + // Fixture for CLGEMMMatrixMultiplyReshaped3D template <typename T> using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>; +// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision +template <typename T> +using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture = + GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>; + namespace { // *INDENT-OFF* @@ -71,15 +81,12 @@ namespace RelativeTolerance<float> rel_tolerance_f32(0.001f); constexpr float abs_tolerance_f32(0.0001f); +RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f); +constexpr float abs_tolerance_f16_mixed_precision(0.01f); + RelativeTolerance<float> rel_tolerance_f16(0.001f); constexpr float abs_tolerance_f16(0.01f); -/** Alpha values to test - Precommit */ -const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); - -/** Beta values to test - Precommit */ -const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} ); - /** M values to test */ const auto m_values = framework::dataset::make("M", 37); @@ -105,6 +112,12 @@ const auto act_values = framework::dataset::make("Activation", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f), }); +/** Alpha values to test - Precommit */ +const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} ); + /** M0 values to test - Precommit */ const auto m0_values_precommit = framework::dataset::make("M0", { 4 }); @@ -120,6 +133,12 @@ const auto v0_values_precommit = framework::dataset::make("V0", 1, 3); /** H0 values to test - Precommit */ const auto h0_values_precommit = framework::dataset::make("H0", 1, 3); +/** Alpha values to test - Nightly */ +const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} ); + +/** Beta values to test - Nightly */ +const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} ); + /** M0 values to test - Nightly */ const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 }); @@ -167,8 +186,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -191,8 +210,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -216,8 +235,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -240,8 +259,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -266,8 +285,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -290,8 +309,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -315,8 +334,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -339,8 +358,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -348,6 +367,105 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); } TEST_SUITE_END() // FP16 + +TEST_SUITE(MixedPrecision) + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} +TEST_SUITE_END() // MixedPrecision TEST_SUITE_END() // Float TEST_SUITE_END() // GEMMMatrixMultiplyReshaped TEST_SUITE_END() // CL diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index 854cc4a22b..bf919c9b09 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -667,7 +667,7 @@ protected: SimpleTensor<T> _reference{}; }; -template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType> +template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false> class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture { public: @@ -734,6 +734,7 @@ protected: kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; kernel_info.activation_info = act_info; + kernel_info.fp_mixed_precision = fp_mixed_precision; // The output tensor will be auto-initialized within the function @@ -807,14 +808,21 @@ protected: } } - return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info); + if(fp_mixed_precision) + { + return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info); + } + else + { + return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info); + } } TensorType _target{}; SimpleTensor<T> _reference{}; }; -template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType> +template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false> class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture { public: @@ -879,6 +887,7 @@ protected: kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = true; kernel_info.activation_info = act_info; + kernel_info.fp_mixed_precision = fp_mixed_precision; // The output tensor will be auto-initialized within the function @@ -951,7 +960,14 @@ protected: memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); } - return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info); + if(fp_mixed_precision) + { + return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info); + } + else + { + return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info); + } } TensorType _target{}; diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp index 2feab89950..3c72b94143 100644 --- a/tests/validation/reference/GEMM.cpp +++ b/tests/validation/reference/GEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -84,8 +84,61 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S return dst; } +template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> +SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta) +{ + // GEMM mixed-precision combines F32 accumulators with F16 multiplications + // Create reference + SimpleTensor<T> dst{ c.shape(), c.data_type(), 1 }; + + // Compute reference + const int M = a.shape().y(); + const int N = b.shape().x(); + const int K = a.shape().x(); + const int D = a.shape().z(); // Number of matrices in a batch + const int W = a.shape()[3]; // Number of batched-gemm (Winograd case) + + const int a_stride_z = K * M; + const int a_stride_w = K * M * D; + + const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions + const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions + + const int c_stride_z = N * M; + const int c_stride_w = N * M * D; + + for(int w = 0; w < W; ++w) + { + for(int depth = 0; depth < D; ++depth) + { + const int base_addr_a = depth * a_stride_z + w * a_stride_w; + const int base_addr_b = depth * b_stride_z + w * b_stride_w; + const int base_addr_c = depth * c_stride_z + w * c_stride_w; + + for(int row = 0; row < M; ++row) + { + for(int col = 0; col < N; ++col) + { + float acc(0); + + for(int k = 0; k < K; ++k) + { + acc += static_cast<float>(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]); + } + + // Finalize the result: alpha * A * B + beta * C + dst[base_addr_c + col + row * N] = static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]); + } + } + } + } + + return dst; +} + template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta); template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); +template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h index 39007c60bc..9bcd640770 100644 --- a/tests/validation/reference/GEMM.h +++ b/tests/validation/reference/GEMM.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -38,6 +38,9 @@ namespace reference template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0> SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta); +template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0> +SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta); + } // namespace reference } // namespace validation } // namespace test |