diff options
author | steniu01 <steven.niu@arm.com> | 2017-08-09 16:26:22 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | db00668890e1aba956e02fa02e1383b54dfd1435 (patch) | |
tree | e20cc07d9bc9eb4bf613213007a2351f5d4eec60 /tests | |
parent | ff6ab352f4f6715b7028a39d8722759d19d2524b (diff) | |
download | ComputeLibrary-db00668890e1aba956e02fa02e1383b54dfd1435.tar.gz |
COMPMID-478 Implemnt CL direct convolution 5x5
Change-Id: I4b975aff310cda9964d8c5dcee182d5d5c82741b
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83474
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/datasets_new/ShapeDatasets.h | 2 | ||||
-rw-r--r-- | tests/validation_new/CL/DirectConvolutionLayer.cpp | 17 |
2 files changed, 15 insertions, 4 deletions
diff --git a/tests/datasets_new/ShapeDatasets.h b/tests/datasets_new/ShapeDatasets.h index 14f7851621..f6cd3f2d0e 100644 --- a/tests/datasets_new/ShapeDatasets.h +++ b/tests/datasets_new/ShapeDatasets.h @@ -115,7 +115,7 @@ public: SmallDirectConvolutionShapes() : ShapeDataset("InputShape", { - TensorShape{ 3U, 3U, 3U, 2U, 4U, 5U }, + TensorShape{ 5U, 5U, 3U, 2U, 4U, 5U }, TensorShape{ 32U, 37U, 3U }, TensorShape{ 13U, 15U, 8U, 3U } }) diff --git a/tests/validation_new/CL/DirectConvolutionLayer.cpp b/tests/validation_new/CL/DirectConvolutionLayer.cpp index d82f535136..1c698ace0f 100644 --- a/tests/validation_new/CL/DirectConvolutionLayer.cpp +++ b/tests/validation_new/CL/DirectConvolutionLayer.cpp @@ -50,6 +50,17 @@ constexpr AbsoluteTolerance<int8_t> tolerance_qs8(0); /**< Tolerance for fixed constexpr AbsoluteTolerance<int16_t> tolerance_qs16(0); /**< Tolerance for fixed point tests */ /** Direct convolution data set. */ +const auto data_quantized = combine(datasets::SmallDirectConvolutionShapes(), + combine(framework::dataset::make("StrideX", 1, 3), + combine(framework::dataset::make("StrideY", 1, 3), + combine(concat(combine(framework::dataset::make("PadX", 0), + combine(framework::dataset::make("PadY", 0), + framework::dataset::make("KernelSize", 1))), + combine(framework::dataset::make("PadX", 0, 2), + combine(framework::dataset::make("PadY", 0, 2), + framework::dataset::make("KernelSize", { 3 })))), + framework::dataset::make("NumKernels", { 1, 4, 8, 16 }))))); + const auto data = combine(datasets::SmallDirectConvolutionShapes(), combine(framework::dataset::make("StrideX", 1, 3), combine(framework::dataset::make("StrideY", 1, 3), @@ -58,7 +69,7 @@ const auto data = combine(datasets::SmallDirectConvolutionShapes(), framework::dataset::make("KernelSize", 1))), combine(framework::dataset::make("PadX", 0, 2), combine(framework::dataset::make("PadY", 0, 2), - framework::dataset::make("KernelSize", 3)))), + framework::dataset::make("KernelSize", { 3, 5 })))), framework::dataset::make("NumKernels", { 1, 4, 8, 16 }))))); } // namespace @@ -93,7 +104,7 @@ using CLDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFix TEST_SUITE(Quantized) TEST_SUITE(QS8) -FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS8)), +FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data_quantized, framework::dataset::make("DataType", DataType::QS8)), framework::dataset::make("FractionalBits", 2, 7))) { // Validate output @@ -102,7 +113,7 @@ FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, f TEST_SUITE_END() TEST_SUITE(QS16) -FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS16)), +FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data_quantized, framework::dataset::make("DataType", DataType::QS16)), framework::dataset::make("FractionalBits", 2, 15))) { // Validate output |