aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp66
1 files changed, 66 insertions, 0 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index 6a1d495576..bd0cd03ca7 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -131,6 +131,32 @@ const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, fals
/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
+/** Boundary handling cases for testing partial/non-partial (full) block dimensions, resulting from different combinations
+ * of M, M0, N and N0 values.
+ * M0 and N0 are kept constant, while the different test cases need to vary M and N.
+ *
+ * Eg. M = 64 and N = 33 result in a block dimension that has no partial blocks (all full blocks) in Y dimension and
+ * parital blocks in X dimension.
+ */
+const auto boundary_handling_cases = combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ // Large k to force potential out-of-bound reads on input0
+ framework::dataset::make("K", 315),
+ // Batch size == 1 to force potential out-of-bound reads on input0
+ framework::dataset::make("batch_size", 1)),
+ framework::dataset::make("M0", 4)),
+ framework::dataset::make("N0", 4)),
+ framework::dataset::make("K0", 4)),
+ framework::dataset::make("H0", 3)),
+ i_values_rhs),
+ t_values_rhs),
+ framework::dataset::make("export_to_cl_image_rhs", {true, false})),
+ // Only need to test F32 as F16 shares identical boundary handling logics
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("alpha", -0.75f )),
+ framework::dataset::make("beta", -0.35f )),
+ broadcast_bias_values),
+ framework::dataset::make("Activation", ActivationLayerInfo()));
+
/** Configuration test */
bool validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value,
unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value,
@@ -330,6 +356,46 @@ m_value, n_value, m0_value, n0_value, export_to_cl_image)
TEST_SUITE(Float)
TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunPrecommitBoundaryHandlingPartialInXPartialInY, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(
+ framework::dataset::make("M", 3),
+ framework::dataset::make("N", 1)),
+ boundary_handling_cases))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunPrecommitBoundaryHandlingPartialInXFullInY, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(
+ framework::dataset::make("M", 64),
+ framework::dataset::make("N", 43)),
+ boundary_handling_cases))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunPrecommitBoundaryHandlingFullInXFullInY, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(
+ framework::dataset::make("M", 64),
+ framework::dataset::make("N", 32)),
+ boundary_handling_cases))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunPrecommitBoundaryHandlingFullInXPartialInY, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(
+ framework::dataset::make("M", 37),
+ framework::dataset::make("N", 32)),
+ boundary_handling_cases))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+}
+
FIXTURE_DATA_TEST_CASE(RunPrecommit, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,