aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMReshapeLHSMatrix.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/GEMMReshapeLHSMatrix.cpp')
-rw-r--r--tests/validation/CL/GEMMReshapeLHSMatrix.cpp39
1 files changed, 37 insertions, 2 deletions
diff --git a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp
index 894b83701f..0dd9b811f6 100644
--- a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp
+++ b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp
@@ -66,8 +66,10 @@ const auto b_values = framework::dataset::make("batchsize", 1, 3);
/** M0 values to test */
const auto m0_values_s32 = framework::dataset::make("M0", { 2, 3 });
-const auto m0_values_s16 = framework::dataset::make("M0", { 4, 5 });
-const auto m0_values_s8 = framework::dataset::make("M0", { 6, 7, 8 });
+const auto m0_values_s16 = framework::dataset::make("M0", { 4 });
+const auto m0_values_s16_nt = framework::dataset::make("M0", { 5 });
+const auto m0_values_s8_nt = framework::dataset::make("M0", { 6,7 });
+const auto m0_values_s8 = framework::dataset::make("M0", { 8 });
/** K0 values to test */
const auto k0_values_s32 = framework::dataset::make("K0", { 2, 3 });
@@ -101,6 +103,7 @@ FIXTURE_DATA_TEST_CASE(S32, CLGEMMReshapeLHSMatrixFixture<int>, framework::Datas
// Validate output
validate(CLAccessor(_target), _reference);
}
+
FIXTURE_DATA_TEST_CASE(S16, CLGEMMReshapeLHSMatrixFixture<short>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
b_values),
@@ -114,6 +117,7 @@ FIXTURE_DATA_TEST_CASE(S16, CLGEMMReshapeLHSMatrixFixture<short>, framework::Dat
// Validate output
validate(CLAccessor(_target), _reference);
}
+
FIXTURE_DATA_TEST_CASE(S8, CLGEMMReshapeLHSMatrixFixture<char>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
b_values),
@@ -128,6 +132,37 @@ FIXTURE_DATA_TEST_CASE(S8, CLGEMMReshapeLHSMatrixFixture<char>, framework::Datas
validate(CLAccessor(_target), _reference);
}
+TEST_SUITE(NotTransposed)
+FIXTURE_DATA_TEST_CASE(S16, CLGEMMReshapeLHSMatrixFixture<short>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
+ b_values),
+ framework::dataset::make("DataType", DataType::S16)),
+ m0_values_s16_nt),
+ k0_values_s16),
+ v0_values),
+ i_values),
+ framework::dataset::make("transpose", { false })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(S8, CLGEMMReshapeLHSMatrixFixture<char>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(),
+ b_values),
+ framework::dataset::make("DataType", DataType::S8)),
+ m0_values_s8_nt),
+ k0_values_s8),
+ v0_values),
+ i_values),
+ framework::dataset::make("transpose", { false })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+TEST_SUITE_END()
+
TEST_SUITE(ReinterpretInputAs3D)
FIXTURE_DATA_TEST_CASE(S32, CLGEMMReshapeLHSMatrix3DFixture<int>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(),