From 2732cca12bac29e1515cee1db5005c73893c61b4 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 28 May 2019 11:44:41 +0100 Subject: COMPMID-2244: Extend CLFuseBatchNormalization to support DepthwiseConvolution weights Change-Id: I7d1907f35cc4899379073759be2f7cce24e51e9d Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/1327 Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- tests/validation/CL/FuseBatchNormalization.cpp | 73 ++++++++++++++++++++++ .../fixtures/FuseBatchNormalizationFixture.h | 5 +- 2 files changed, 77 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/validation/CL/FuseBatchNormalization.cpp b/tests/validation/CL/FuseBatchNormalization.cpp index 92d63c0c3d..35414b765a 100644 --- a/tests/validation/CL/FuseBatchNormalization.cpp +++ b/tests/validation/CL/FuseBatchNormalization.cpp @@ -44,6 +44,8 @@ AbsoluteTolerance absolute_tolerance_f16(0.2f); template using CLFuseBatchNormalizationConvFixture = FuseBatchNormalizationFixture; +template +using CLFuseBatchNormalizationDWCFixture = FuseBatchNormalizationFixture; // *INDENT-OFF* // clang-format off @@ -140,6 +142,77 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationConvFixture, fram TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float TEST_SUITE_END() // Convolution + +TEST_SUITE(DepthwiseConvolution) +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // Float +TEST_SUITE_END() // DepthwiseConvolution + TEST_SUITE_END() // FuseBatchNormalization TEST_SUITE_END() // CL } // namespace validation diff --git a/tests/validation/fixtures/FuseBatchNormalizationFixture.h b/tests/validation/fixtures/FuseBatchNormalizationFixture.h index 864d627ed7..2e76792b3e 100644 --- a/tests/validation/fixtures/FuseBatchNormalizationFixture.h +++ b/tests/validation/fixtures/FuseBatchNormalizationFixture.h @@ -90,9 +90,12 @@ protected: auto w_fused_to_use = in_place_w ? nullptr : &w_fused; auto b_fused_to_use = in_place_b ? nullptr : &b_fused; + const FuseBatchNormalizationType fuse_bn_type = dims_weights == 3 ? + FuseBatchNormalizationType::DEPTHWISECONVOLUTION : + FuseBatchNormalizationType::CONVOLUTION; // Create and configure function FunctionType fuse_batch_normalization; - fuse_batch_normalization.configure(&w, &mean, &var, w_fused_to_use, b_fused_to_use, b_to_use, beta_to_use, gamma_to_use, _epsilon); + fuse_batch_normalization.configure(&w, &mean, &var, w_fused_to_use, b_fused_to_use, b_to_use, beta_to_use, gamma_to_use, _epsilon, fuse_bn_type); ARM_COMPUTE_EXPECT(w.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS); -- cgit v1.2.1