aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMMatrixMultiply.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiply.cpp')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiply.cpp22
1 files changed, 14 insertions, 8 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiply.cpp b/tests/validation/CL/GEMMMatrixMultiply.cpp
index 5d2e211d91..fdf7f503ec 100644
--- a/tests/validation/CL/GEMMMatrixMultiply.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiply.cpp
@@ -72,8 +72,16 @@ const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
/** Beta values to test */
const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} );
-/** M values to test */
-const auto m_values = framework::dataset::make("M", {37, 1});
+/** M, N combinations to test
+ * 1: Special 1x1 case
+ * 2: Special multples of processor size in both dimensions
+ * 3: Non multiples of processor size in both dimensions
+ * 4: Special 1x1003 case
+*/
+const auto m_n_values = zip(
+ framework::dataset::make("M", {1, 16, 37, 1}),
+ framework::dataset::make("N", {1, 16, 51, 1003})
+ );
/** N values to test */
const auto n_values = framework::dataset::make("N", {51, 1003});
@@ -247,9 +255,8 @@ TEST_SUITE(Float)
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_n_values,
k_values),
b_values),
alpha_values),
@@ -287,9 +294,8 @@ TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyNativeFixture<half>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_n_values,
k_values),
b_values),
alpha_values),