aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2017-08-14 11:26:37 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitdef665a1a2e92baa1cfb192b65425b91ff6046b3 (patch)
tree03271d3e78190d81709618f40191d105b7c94917 /tests
parentfc2817dc0436ef2d5064df0a061aafd3d324d894 (diff)
downloadComputeLibrary-def665a1a2e92baa1cfb192b65425b91ff6046b3.tar.gz
COMPMID-474 - Add support for QS8/QS16 DirectConvolution CL
Change-Id: I537e4acbc02c8d880ff8630ea62223e0f1a1dda3 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/82875 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation_new/CL/DirectConvolutionLayer.cpp26
1 files changed, 26 insertions, 0 deletions
diff --git a/tests/validation_new/CL/DirectConvolutionLayer.cpp b/tests/validation_new/CL/DirectConvolutionLayer.cpp
index 1a7cd6b3fb..d82f535136 100644
--- a/tests/validation_new/CL/DirectConvolutionLayer.cpp
+++ b/tests/validation_new/CL/DirectConvolutionLayer.cpp
@@ -46,6 +46,9 @@ namespace
constexpr AbsoluteTolerance<float> tolerance_fp16(0.1f); /**< Tolerance for floating point tests */
constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance for floating point tests */
+constexpr AbsoluteTolerance<int8_t> tolerance_qs8(0); /**< Tolerance for fixed point tests */
+constexpr AbsoluteTolerance<int16_t> tolerance_qs16(0); /**< Tolerance for fixed point tests */
+
/** Direct convolution data set. */
const auto data = combine(datasets::SmallDirectConvolutionShapes(),
combine(framework::dataset::make("StrideX", 1, 3),
@@ -85,6 +88,29 @@ FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture<float>, framework::D
TEST_SUITE_END()
TEST_SUITE_END()
+template <typename T>
+using CLDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFixedPointFixture<CLTensor, CLAccessor, CLDirectConvolutionLayer, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QS8)
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS8)),
+ framework::dataset::make("FractionalBits", 2, 7)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qs8);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(QS16)
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::QS16)),
+ framework::dataset::make("FractionalBits", 2, 15)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qs16);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
TEST_SUITE_END()
TEST_SUITE_END()
} // namespace validation