diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2019-05-28 11:44:41 +0100 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2019-06-13 16:01:42 +0000 |
commit | 2732cca12bac29e1515cee1db5005c73893c61b4 (patch) | |
tree | 050d4c20b51b2b642be21512f9b4a900e18ce88c /tests/validation/CL/FuseBatchNormalization.cpp | |
parent | b3a0a60d0b570c58d84324059abb5caceae2561c (diff) | |
download | ComputeLibrary-2732cca12bac29e1515cee1db5005c73893c61b4.tar.gz |
COMPMID-2244: Extend CLFuseBatchNormalization to support DepthwiseConvolution weights
Change-Id: I7d1907f35cc4899379073759be2f7cce24e51e9d
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1327
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL/FuseBatchNormalization.cpp')
-rw-r--r-- | tests/validation/CL/FuseBatchNormalization.cpp | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/tests/validation/CL/FuseBatchNormalization.cpp b/tests/validation/CL/FuseBatchNormalization.cpp index 92d63c0c3d..35414b765a 100644 --- a/tests/validation/CL/FuseBatchNormalization.cpp +++ b/tests/validation/CL/FuseBatchNormalization.cpp @@ -44,6 +44,8 @@ AbsoluteTolerance<float> absolute_tolerance_f16(0.2f); template <typename T> using CLFuseBatchNormalizationConvFixture = FuseBatchNormalizationFixture<CLTensor, CLAccessor, CLFuseBatchNormalization, 4, T>; +template <typename T> +using CLFuseBatchNormalizationDWCFixture = FuseBatchNormalizationFixture<CLTensor, CLAccessor, CLFuseBatchNormalization, 3, T>; // *INDENT-OFF* // clang-format off @@ -140,6 +142,77 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationConvFixture<half>, fram TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float TEST_SUITE_END() // Convolution + +TEST_SUITE(DepthwiseConvolution) +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture<float>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture<float>, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F32 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f32); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f32); +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLFuseBatchNormalizationDWCFixture<half>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine( + datasets::Small3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLFuseBatchNormalizationDWCFixture<half>, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine( + datasets::Large3DShapes(), + framework::dataset::make("DataType", { DataType::F16 })), + data_layout_values), + in_place_values), + with_bias_values), + with_gamma_values), + with_beta_values)) +{ + // Validate outputs + validate(CLAccessor(_target_w), _reference_w, absolute_tolerance_f16); + validate(CLAccessor(_target_b), _reference_b, absolute_tolerance_f16); +} + +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // Float +TEST_SUITE_END() // DepthwiseConvolution + TEST_SUITE_END() // FuseBatchNormalization TEST_SUITE_END() // CL } // namespace validation |