diff options
Diffstat (limited to 'tests/validation_new/CL/DirectConvolutionLayer.cpp')
-rw-r--r-- | tests/validation_new/CL/DirectConvolutionLayer.cpp | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tests/validation_new/CL/DirectConvolutionLayer.cpp b/tests/validation_new/CL/DirectConvolutionLayer.cpp index 1a7cd6b3fb..d82f535136 100644 --- a/tests/validation_new/CL/DirectConvolutionLayer.cpp +++ b/tests/validation_new/CL/DirectConvolutionLayer.cpp @@ -46,6 +46,9 @@ namespace constexpr AbsoluteTolerance<float> tolerance_fp16(0.1f); /**< Tolerance for floating point tests */ constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance for floating point tests */ +constexpr AbsoluteTolerance<int8_t> tolerance_qs8(0); /**< Tolerance for fixed point tests */ +constexpr AbsoluteTolerance<int16_t> tolerance_qs16(0); /**< Tolerance for fixed point tests */ + /** Direct convolution data set. */ const auto data = combine(datasets::SmallDirectConvolutionShapes(), combine(framework::dataset::make("StrideX", 1, 3), @@ -85,6 +88,29 @@ FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture<float>, framework::D TEST_SUITE_END() TEST_SUITE_END() +template <typename T> +using CLDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFixedPointFixture<CLTensor, CLAccessor, CLDirectConvolutionLayer, T>; + +TEST_SUITE(Quantized) +TEST_SUITE(QS8) +FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS8)), + framework::dataset::make("FractionalBits", 2, 7))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_qs8); +} +TEST_SUITE_END() + +TEST_SUITE(QS16) +FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS16)), + framework::dataset::make("FractionalBits", 2, 15))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_qs16); +} +TEST_SUITE_END() +TEST_SUITE_END() + TEST_SUITE_END() TEST_SUITE_END() } // namespace validation |