From 2732cca12bac29e1515cee1db5005c73893c61b4 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 28 May 2019 11:44:41 +0100 Subject: COMPMID-2244: Extend CLFuseBatchNormalization to support DepthwiseConvolution weights Change-Id: I7d1907f35cc4899379073759be2f7cce24e51e9d Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/1327 Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- tests/validation/CL/FuseBatchNormalization.cpp | 73 ++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) (limited to 'tests/validation/CL/FuseBatchNormalization.cpp') diff --git a/tests/validation/CL/FuseBatchNormalization.cpp b/tests/validation/CL/FuseBatchNormalization.cpp index 92d63c0c3d..35414b765a 100644 --- a/tests/validation/CL/FuseBatchNormalization.cpp +++ b/tests/validation/CL/FuseBatchNormalization.cpp @@ -44,6 +44,8 @@ AbsoluteTolerance absolute_tolerance_f16(0.2f); template using CLFuseBatchNormalizationConvFixture = FuseBatchNormalizationFixture; +template +using CLFuseBatchNormalizationDWCFixture = FuseBatchNormalizationFixture; // *INDENT-OFF* // clang-format off @@ -140,6 +142,77 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationConvFixture, fram TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float TEST_SUITE_END() // Convolution + +TEST_SUITE(DepthwiseConvolution) +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // Float +TEST_SUITE_END() // DepthwiseConvolution + TEST_SUITE_END() // FuseBatchNormalization TEST_SUITE_END() // CL } // namespace validation -- cgit v1.2.1