aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-27 09:23:15 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-30 08:28:43 +0000
commit0c17aa25a4f7bc812707150b91930f0cf8e75294 (patch)
tree29088e00bd7ba443dc122ad3436b0a4ef369a102 /tests
parent40958adf8bad8fd9fefe591ee55a381f7bbb6fea (diff)
downloadComputeLibrary-0c17aa25a4f7bc812707150b91930f0cf8e75294.tar.gz
COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16
Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/1991 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp162
-rw-r--r--tests/validation/fixtures/GEMMFixture.h24
-rw-r--r--tests/validation/reference/GEMM.cpp55
-rw-r--r--tests/validation/reference/GEMM.h5
4 files changed, 218 insertions, 28 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index 99f5ffe191..b885bfe4af 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -60,10 +60,20 @@ using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMMatrixMultiplyRe
template <typename T>
using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
+// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision
+template <typename T>
+using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
+ GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
+
// Fixture for CLGEMMMatrixMultiplyReshaped3D
template <typename T>
using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
+// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision
+template <typename T>
+using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
+ GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
+
namespace
{
// *INDENT-OFF*
@@ -71,15 +81,12 @@ namespace
RelativeTolerance<float> rel_tolerance_f32(0.001f);
constexpr float abs_tolerance_f32(0.0001f);
+RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f);
+constexpr float abs_tolerance_f16_mixed_precision(0.01f);
+
RelativeTolerance<float> rel_tolerance_f16(0.001f);
constexpr float abs_tolerance_f16(0.01f);
-/** Alpha values to test - Precommit */
-const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
-
-/** Beta values to test - Precommit */
-const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} );
-
/** M values to test */
const auto m_values = framework::dataset::make("M", 37);
@@ -105,6 +112,12 @@ const auto act_values = framework::dataset::make("Activation",
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
});
+/** Alpha values to test - Precommit */
+const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} );
+
+/** Beta values to test - Precommit */
+const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} );
+
/** M0 values to test - Precommit */
const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
@@ -120,6 +133,12 @@ const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);
/** H0 values to test - Precommit */
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
+/** Alpha values to test - Nightly */
+const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );
+
+/** Beta values to test - Nightly */
+const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );
+
/** M0 values to test - Nightly */
const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 });
@@ -167,8 +186,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
+ a_values_precommit),
+ beta_values_precommit),
broadcast_bias_values),
lhs_transpose_values),
act_values))
@@ -191,8 +210,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
+ a_values_nightly),
+ beta_values_nightly),
broadcast_bias_values),
lhs_transpose_values),
act_values))
@@ -216,8 +235,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
+ a_values_precommit),
+ beta_values_precommit),
lhs_transpose_values),
act_values))
{
@@ -240,8 +259,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
+ a_values_nightly),
+ beta_values_nightly),
lhs_transpose_values),
act_values))
{
@@ -266,8 +285,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, fram
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
+ a_values_precommit),
+ beta_values_precommit),
broadcast_bias_values),
lhs_transpose_values),
act_values))
@@ -290,8 +309,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, fram
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
+ a_values_nightly),
+ beta_values_nightly),
broadcast_bias_values),
lhs_transpose_values),
act_values))
@@ -315,8 +334,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
+ a_values_precommit),
+ beta_values_precommit),
lhs_transpose_values),
act_values))
{
@@ -339,8 +358,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
+ a_values_nightly),
+ beta_values_nightly),
lhs_transpose_values),
act_values))
{
@@ -348,6 +367,105 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
TEST_SUITE_END() // FP16
+
+TEST_SUITE(MixedPrecision)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ v0_values_precommit),
+ h0_values_precommit),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_precommit),
+ beta_values_precommit),
+ broadcast_bias_values),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_nightly),
+ n0_values_nightly),
+ k0_values_nightly),
+ v0_values_nightly),
+ h0_values_nightly),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_nightly),
+ beta_values_nightly),
+ broadcast_bias_values),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ v0_values_precommit),
+ h0_values_precommit),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_precommit),
+ beta_values_precommit),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
+ m0_values_nightly),
+ n0_values_nightly),
+ k0_values_nightly),
+ v0_values_nightly),
+ h0_values_nightly),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values_nightly),
+ beta_values_nightly),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+TEST_SUITE_END() // MixedPrecision
TEST_SUITE_END() // Float
TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
TEST_SUITE_END() // CL
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 854cc4a22b..bf919c9b09 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -667,7 +667,7 @@ protected:
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture
{
public:
@@ -734,6 +734,7 @@ protected:
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = broadcast_bias;
kernel_info.activation_info = act_info;
+ kernel_info.fp_mixed_precision = fp_mixed_precision;
// The output tensor will be auto-initialized within the function
@@ -807,14 +808,21 @@ protected:
}
}
- return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ if(fp_mixed_precision)
+ {
+ return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
+ else
+ {
+ return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture
{
public:
@@ -879,6 +887,7 @@ protected:
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = true;
kernel_info.activation_info = act_info;
+ kernel_info.fp_mixed_precision = fp_mixed_precision;
// The output tensor will be auto-initialized within the function
@@ -951,7 +960,14 @@ protected:
memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
}
- return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ if(fp_mixed_precision)
+ {
+ return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
+ else
+ {
+ return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
}
TensorType _target{};
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 2feab89950..3c72b94143 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -84,8 +84,61 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
return dst;
}
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
+SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
+{
+ // GEMM mixed-precision combines F32 accumulators with F16 multiplications
+ // Create reference
+ SimpleTensor<T> dst{ c.shape(), c.data_type(), 1 };
+
+ // Compute reference
+ const int M = a.shape().y();
+ const int N = b.shape().x();
+ const int K = a.shape().x();
+ const int D = a.shape().z(); // Number of matrices in a batch
+ const int W = a.shape()[3]; // Number of batched-gemm (Winograd case)
+
+ const int a_stride_z = K * M;
+ const int a_stride_w = K * M * D;
+
+ const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions
+ const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
+
+ const int c_stride_z = N * M;
+ const int c_stride_w = N * M * D;
+
+ for(int w = 0; w < W; ++w)
+ {
+ for(int depth = 0; depth < D; ++depth)
+ {
+ const int base_addr_a = depth * a_stride_z + w * a_stride_w;
+ const int base_addr_b = depth * b_stride_z + w * b_stride_w;
+ const int base_addr_c = depth * c_stride_z + w * c_stride_w;
+
+ for(int row = 0; row < M; ++row)
+ {
+ for(int col = 0; col < N; ++col)
+ {
+ float acc(0);
+
+ for(int k = 0; k < K; ++k)
+ {
+ acc += static_cast<float>(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]);
+ }
+
+ // Finalize the result: alpha * A * B + beta * C
+ dst[base_addr_c + col + row * N] = static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]);
+ }
+ }
+ }
+ }
+
+ return dst;
+}
+
template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h
index 39007c60bc..9bcd640770 100644
--- a/tests/validation/reference/GEMM.h
+++ b/tests/validation/reference/GEMM.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,6 +38,9 @@ namespace reference
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta);
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
+SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta);
+
} // namespace reference
} // namespace validation
} // namespace test