aboutsummaryrefslogtreecommitdiff
path: root/tests/validation
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation')
-rw-r--r--tests/validation/CL/GEMM.cpp40
-rw-r--r--tests/validation/fixtures/GEMMFixture.h9
2 files changed, 45 insertions, 4 deletions
diff --git a/tests/validation/CL/GEMM.cpp b/tests/validation/CL/GEMM.cpp
index 217edf4438..4e6cf826fa 100644
--- a/tests/validation/CL/GEMM.cpp
+++ b/tests/validation/CL/GEMM.cpp
@@ -250,8 +250,44 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixedPointFixture<int16_t>, framework::Da
TEST_SUITE_END()
TEST_SUITE_END()
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE(OUTPUT_3D)
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMM3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMM3DDataset(),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+}
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMM3DDataset(),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMM3DDataset(),
+ framework::dataset::make("DataType",
+ DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
+}
+TEST_SUITE_END() // FP16
+
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // OUTPUT_3D
+
+TEST_SUITE_END() // GEMM
+TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index e807ad8c06..8dd2998377 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,12 @@ protected:
// Create and configure function
FunctionType gemm;
- gemm.configure(&a, &b, &c, &dst, alpha, beta);
+ // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output.
+ // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 1),
+ // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output).
+ bool is_output_reinterpreted_as_3D = output_shape.num_dimensions() > shape_a.num_dimensions();
+ gemm.configure(&a, &b, &c, &dst, alpha, beta,
+ GEMMInfo(false, false, false, is_output_reinterpreted_as_3D ? output_shape[2] : 1));
ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);