aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2018-10-30 11:18:37 +0000
committerAnthony Barbier <Anthony.barbier@arm.com>2018-11-15 09:40:26 +0000
commit0e37b5c0e2caaaf18117ec3b1cecff6a85c184f3 (patch)
tree943f5e775cfddfde5e4a3be8cd51b6c6189a4ebe /tests/validation/CL
parentcd96a26f67bfbb9b0efe6e0e2b229d0b46b4e3e6 (diff)
downloadComputeLibrary-0e37b5c0e2caaaf18117ec3b1cecff6a85c184f3.tar.gz
COMPMID-1708: Improve GEMM test coverage.
Added test cases to exercise the code path where the reshaping of B is performed on the fly. Change-Id: Ifa4348e1054dc0019be3927f482adf64b18fd554
Diffstat (limited to 'tests/validation/CL')
-rw-r--r--tests/validation/CL/GEMM.cpp51
1 files changed, 32 insertions, 19 deletions
diff --git a/tests/validation/CL/GEMM.cpp b/tests/validation/CL/GEMM.cpp
index ff2071a756..376032c382 100644
--- a/tests/validation/CL/GEMM.cpp
+++ b/tests/validation/CL/GEMM.cpp
@@ -108,10 +108,10 @@ template <typename T>
using CLGEMMFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T>;
template <typename T>
-using CLGEMMOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, true>;
+using CLGEMMOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, false, true>;
template <typename T>
-using CLGEMMInputOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, true, true>;
+using CLGEMMInputOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, true, true>;
TEST_SUITE(TRANSPOSE_1XW)
using CLGEMMTranspose1xW = CLSynthetizeFunctionWithZeroConstantBorder<CLGEMMTranspose1xWKernel, 4>;
@@ -128,13 +128,16 @@ TEST_SUITE_END() //TRANSPOSE_1XW
TEST_SUITE(Float)
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType",
- DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
@@ -142,12 +145,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NI
TEST_SUITE_END()
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
@@ -158,13 +165,15 @@ TEST_SUITE_END()
TEST_SUITE(INPUT_OUTPUT_3D)
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMInputOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMInputOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
@@ -173,15 +182,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::D
TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMInputOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMInputOutput3DDataset(),
- framework::dataset::make("DataType",
- DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
@@ -194,13 +204,15 @@ TEST_SUITE_END() // INPUT_OUTPUT_3D
TEST_SUITE(OUTPUT_3D)
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
@@ -209,15 +221,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::Datase
TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMOutput3DDataset(),
- framework::dataset::make("DataType",
- DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);