aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp21
1 files changed, 13 insertions, 8 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
index b2701e7f6c..d6507a06c4 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
@@ -82,8 +82,15 @@ const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
/** Beta values to test */
const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} );
-/** M values to test */
-const auto m_values = framework::dataset::make("M", {37, 1});
+/** M, N combinations to test
+ * 1: Special 1x1 case
+ * 2: Special multples of processor size in both dimensions
+ * 3: Non multiples of processor size in both dimensions
+*/
+const auto m_n_values = zip(
+ framework::dataset::make("M", {1, 16, 37}),
+ framework::dataset::make("N", {1, 16, 51})
+ );
/** N values to test */
const auto n_values = framework::dataset::make("N", 51);
@@ -235,9 +242,8 @@ TEST_CASE(Negative, framework::DatasetMode::ALL)
TEST_SUITE(Float)
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_n_values,
k_values),
b_values),
alpha_values),
@@ -279,9 +285,8 @@ TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_n_values,
k_values),
b_values),
alpha_values),