From def665a1a2e92baa1cfb192b65425b91ff6046b3 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Mon, 14 Aug 2017 11:26:37 +0100 Subject: COMPMID-474 - Add support for QS8/QS16 DirectConvolution CL Change-Id: I537e4acbc02c8d880ff8630ea62223e0f1a1dda3 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/82875 Tested-by: Kaizen Reviewed-by: Pablo Tello --- tests/validation_new/CL/DirectConvolutionLayer.cpp | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) (limited to 'tests/validation_new') diff --git a/tests/validation_new/CL/DirectConvolutionLayer.cpp b/tests/validation_new/CL/DirectConvolutionLayer.cpp index 1a7cd6b3fb..d82f535136 100644 --- a/tests/validation_new/CL/DirectConvolutionLayer.cpp +++ b/tests/validation_new/CL/DirectConvolutionLayer.cpp @@ -46,6 +46,9 @@ namespace constexpr AbsoluteTolerance tolerance_fp16(0.1f); /**< Tolerance for floating point tests */ constexpr AbsoluteTolerance tolerance_fp32(0.001f); /**< Tolerance for floating point tests */ +constexpr AbsoluteTolerance tolerance_qs8(0); /**< Tolerance for fixed point tests */ +constexpr AbsoluteTolerance tolerance_qs16(0); /**< Tolerance for fixed point tests */ + /** Direct convolution data set. */ const auto data = combine(datasets::SmallDirectConvolutionShapes(), combine(framework::dataset::make("StrideX", 1, 3), @@ -85,6 +88,29 @@ FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture, framework::D TEST_SUITE_END() TEST_SUITE_END() +template +using CLDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFixedPointFixture; + +TEST_SUITE(Quantized) +TEST_SUITE(QS8) +FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS8)), + framework::dataset::make("FractionalBits", 2, 7))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_qs8); +} +TEST_SUITE_END() + +TEST_SUITE(QS16) +FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS16)), + framework::dataset::make("FractionalBits", 2, 15))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_qs16); +} +TEST_SUITE_END() +TEST_SUITE_END() + TEST_SUITE_END() TEST_SUITE_END() } // namespace validation -- cgit v1.2.1