aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorsteniu01 <steven.niu@arm.com>2017-08-09 16:26:22 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitdb00668890e1aba956e02fa02e1383b54dfd1435 (patch)
treee20cc07d9bc9eb4bf613213007a2351f5d4eec60 /tests
parentff6ab352f4f6715b7028a39d8722759d19d2524b (diff)
downloadComputeLibrary-db00668890e1aba956e02fa02e1383b54dfd1435.tar.gz
COMPMID-478 Implemnt CL direct convolution 5x5
Change-Id: I4b975aff310cda9964d8c5dcee182d5d5c82741b Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83474 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/datasets_new/ShapeDatasets.h2
-rw-r--r--tests/validation_new/CL/DirectConvolutionLayer.cpp17
2 files changed, 15 insertions, 4 deletions
diff --git a/tests/datasets_new/ShapeDatasets.h b/tests/datasets_new/ShapeDatasets.h
index 14f7851621..f6cd3f2d0e 100644
--- a/tests/datasets_new/ShapeDatasets.h
+++ b/tests/datasets_new/ShapeDatasets.h
@@ -115,7 +115,7 @@ public:
SmallDirectConvolutionShapes()
: ShapeDataset("InputShape",
{
- TensorShape{ 3U, 3U, 3U, 2U, 4U, 5U },
+ TensorShape{ 5U, 5U, 3U, 2U, 4U, 5U },
TensorShape{ 32U, 37U, 3U },
TensorShape{ 13U, 15U, 8U, 3U }
})
diff --git a/tests/validation_new/CL/DirectConvolutionLayer.cpp b/tests/validation_new/CL/DirectConvolutionLayer.cpp
index d82f535136..1c698ace0f 100644
--- a/tests/validation_new/CL/DirectConvolutionLayer.cpp
+++ b/tests/validation_new/CL/DirectConvolutionLayer.cpp
@@ -50,6 +50,17 @@ constexpr AbsoluteTolerance<int8_t> tolerance_qs8(0); /**< Tolerance for fixed
constexpr AbsoluteTolerance<int16_t> tolerance_qs16(0); /**< Tolerance for fixed point tests */
/** Direct convolution data set. */
+const auto data_quantized = combine(datasets::SmallDirectConvolutionShapes(),
+ combine(framework::dataset::make("StrideX", 1, 3),
+ combine(framework::dataset::make("StrideY", 1, 3),
+ combine(concat(combine(framework::dataset::make("PadX", 0),
+ combine(framework::dataset::make("PadY", 0),
+ framework::dataset::make("KernelSize", 1))),
+ combine(framework::dataset::make("PadX", 0, 2),
+ combine(framework::dataset::make("PadY", 0, 2),
+ framework::dataset::make("KernelSize", { 3 })))),
+ framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
+
const auto data = combine(datasets::SmallDirectConvolutionShapes(),
combine(framework::dataset::make("StrideX", 1, 3),
combine(framework::dataset::make("StrideY", 1, 3),
@@ -58,7 +69,7 @@ const auto data = combine(datasets::SmallDirectConvolutionShapes(),
framework::dataset::make("KernelSize", 1))),
combine(framework::dataset::make("PadX", 0, 2),
combine(framework::dataset::make("PadY", 0, 2),
- framework::dataset::make("KernelSize", 3)))),
+ framework::dataset::make("KernelSize", { 3, 5 })))),
framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
} // namespace
@@ -93,7 +104,7 @@ using CLDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFix
TEST_SUITE(Quantized)
TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS8)),
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data_quantized, framework::dataset::make("DataType", DataType::QS8)),
framework::dataset::make("FractionalBits", 2, 7)))
{
// Validate output
@@ -102,7 +113,7 @@ FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, f
TEST_SUITE_END()
TEST_SUITE(QS16)
-FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data_quantized, framework::dataset::make("DataType", DataType::QS16)),
framework::dataset::make("FractionalBits", 2, 15)))
{
// Validate output