aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/PoolingLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/PoolingLayer.cpp')
-rw-r--r--tests/validation/CL/PoolingLayer.cpp12
1 files changed, 5 insertions, 7 deletions
diff --git a/tests/validation/CL/PoolingLayer.cpp b/tests/validation/CL/PoolingLayer.cpp
index dc9604423f..9da4c55c78 100644
--- a/tests/validation/CL/PoolingLayer.cpp
+++ b/tests/validation/CL/PoolingLayer.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/CL/functions/CLPoolingLayer.h"
#include "tests/CL/CLAccessor.h"
#include "tests/PaddingCalculator.h"
+#include "tests/datasets/PoolingLayerDataset.h"
#include "tests/datasets/PoolingTypesDataset.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
@@ -43,12 +44,6 @@ namespace validation
{
namespace
{
-/** Failing data set */
-const auto PoolingLayerDatasetSpecial = ((((framework::dataset::make("Shape", TensorShape{ 60U, 52U, 3U, 5U })
- * framework::dataset::make("PoolType", PoolingType::AVG))
- * framework::dataset::make("PoolingSize", Size2D(100, 100)))
- * framework::dataset::make("PadStride", PadStrideInfo(5, 5, 50, 50)))
- * framework::dataset::make("ExcludePadding", true));
/** Input data set for floating-point data types */
const auto PoolingLayerDatasetFP = combine(combine(combine(datasets::PoolingTypes(), framework::dataset::make("PoolingSize", { Size2D(2, 2), Size2D(3, 3), Size2D(7, 7), Size2D(9, 9), Size2D(5, 7), Size2D(7, 9) })),
framework::dataset::make("PadStride", { PadStrideInfo(1, 1, 0, 0), PadStrideInfo(2, 1, 0, 0), PadStrideInfo(1, 2, 1, 1), PadStrideInfo(2, 2, 1, 0) })),
@@ -121,9 +116,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
template <typename T>
using CLPoolingLayerFixture = PoolingLayerValidationFixture<CLTensor, CLAccessor, CLPoolingLayer, T>;
+template <typename T>
+using CLSpecialPoolingLayerFixture = SpecialPoolingLayerValidationFixture<CLTensor, CLAccessor, CLPoolingLayer, T>;
+
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSpecial, CLPoolingLayerFixture<float>, framework::DatasetMode::ALL, PoolingLayerDatasetSpecial * framework::dataset::make("DataType", DataType::F32))
+FIXTURE_DATA_TEST_CASE(RunSpecial, CLSpecialPoolingLayerFixture<float>, framework::DatasetMode::ALL, datasets::PoolingLayerDatasetSpecial() * framework::dataset::make("DataType", DataType::F32))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);