diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2019-05-28 11:44:41 +0100 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2019-06-13 16:01:42 +0000 |
commit | 2732cca12bac29e1515cee1db5005c73893c61b4 (patch) | |
tree | 050d4c20b51b2b642be21512f9b4a900e18ce88c /tests | |
parent | b3a0a60d0b570c58d84324059abb5caceae2561c (diff) | |
download | ComputeLibrary-2732cca12bac29e1515cee1db5005c73893c61b4.tar.gz |
COMPMID-2244: Extend CLFuseBatchNormalization to support DepthwiseConvolution weights
Change-Id: I7d1907f35cc4899379073759be2f7cce24e51e9d
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1327
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/validation/CL/FuseBatchNormalization.cpp | 73 | ||||
-rw-r--r-- | tests/validation/fixtures/FuseBatchNormalizationFixture.h | 5 |
2 files changed, 77 insertions, 1 deletions
diff --git a/tests/validation/CL/FuseBatchNormalization.cpp b/tests/validation/CL/FuseBatchNormalization.cpp index 92d63c0c3d..35414b765a 100644 --- a/tests/validation/CL/FuseBatchNormalization.cpp +++ b/tests/validation/CL/FuseBatchNormalization.cpp @@ -44,6 +44,8 @@ AbsoluteTolerance<float> absolute_tolerance_f16(0.2f); template <typename T> using CLFuseBatchNormalizationConvFixture = FuseBatchNormalizationFixture<CLTensor, CLAccessor, CLFuseBatchNormalization, 4, T>; +template <typename T> +using CLFuseBatchNormalizationDWCFixture = FuseBatchNormalizationFixture<CLTensor, CLAccessor, CLFuseBatchNormalization, 3, T>; // *INDENT-OFF* // clang-format off @@ -140,6 +142,77 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationConvFixture<half>, fram TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float TEST_SUITE_END() // Convolution + +TEST_SUITE(DepthwiseConvolution) +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture<float>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture<float>, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture<half>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture<half>, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // Float +TEST_SUITE_END() // DepthwiseConvolution + TEST_SUITE_END() // FuseBatchNormalization TEST_SUITE_END() // CL } // namespace validation diff --git a/tests/validation/fixtures/FuseBatchNormalizationFixture.h b/tests/validation/fixtures/FuseBatchNormalizationFixture.h index 864d627ed7..2e76792b3e 100644 --- a/tests/validation/fixtures/FuseBatchNormalizationFixture.h +++ b/tests/validation/fixtures/FuseBatchNormalizationFixture.h @@ -90,9 +90,12 @@ protected: auto w_fused_to_use = in_place_w ? nullptr : &w_fused; auto b_fused_to_use = in_place_b ? nullptr : &b_fused; + const FuseBatchNormalizationType fuse_bn_type = dims_weights == 3 ? + FuseBatchNormalizationType::DEPTHWISECONVOLUTION : + FuseBatchNormalizationType::CONVOLUTION; // Create and configure function FunctionType fuse_batch_normalization; - fuse_batch_normalization.configure(&w, &mean, &var, w_fused_to_use, b_fused_to_use, b_to_use, beta_to_use, gamma_to_use, _epsilon); + fuse_batch_normalization.configure(&w, &mean, &var, w_fused_to_use, b_fused_to_use, b_to_use, beta_to_use, gamma_to_use, _epsilon, fuse_bn_type); ARM_COMPUTE_EXPECT(w.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS); |