aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/validation_new/CL/DirectConvolutionLayer.cpp26
1 files changed, 26 insertions, 0 deletions
diff --git a/tests/validation_new/CL/DirectConvolutionLayer.cpp b/tests/validation_new/CL/DirectConvolutionLayer.cpp
index 1a7cd6b3fb..d82f535136 100644
--- a/tests/validation_new/CL/DirectConvolutionLayer.cpp
+++ b/tests/validation_new/CL/DirectConvolutionLayer.cpp
@@ -46,6 +46,9 @@ namespace
constexpr AbsoluteTolerance<float> tolerance_fp16(0.1f); /**< Tolerance for floating point tests */
constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance for floating point tests */
+constexpr AbsoluteTolerance<int8_t> tolerance_qs8(0); /**< Tolerance for fixed point tests */
+constexpr AbsoluteTolerance<int16_t> tolerance_qs16(0); /**< Tolerance for fixed point tests */
+
/** Direct convolution data set. */
const auto data = combine(datasets::SmallDirectConvolutionShapes(),
combine(framework::dataset::make("StrideX", 1, 3),
@@ -85,6 +88,29 @@ FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture<float>, framework::D
TEST_SUITE_END()
TEST_SUITE_END()
+template <typename T>
+using CLDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFixedPointFixture<CLTensor, CLAccessor, CLDirectConvolutionLayer, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QS8)
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS8)),
+ framework::dataset::make("FractionalBits", 2, 7)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qs8);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(QS16)
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS16)),
+ framework::dataset::make("FractionalBits", 2, 15)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qs16);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
TEST_SUITE_END()
TEST_SUITE_END()
} // namespace validation