diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2018-03-02 09:43:54 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:37 +0000 |
commit | 4d33630096c769dd43716dd5607f151e3d5abef7 (patch) | |
tree | 762897c2acac9553c0dad688d0c21842c8edff16 /tests/validation/GLES_COMPUTE | |
parent | 1cd41495153c4e89d6195b42f870967339c1a13b (diff) | |
download | ComputeLibrary-4d33630096c769dd43716dd5607f151e3d5abef7.tar.gz |
COMPMID-987: Make beta and gamma optional in BatchNormalization
Currently we have beta and gamma compulsory in Batch normalization. There are
network that might not need one or both of those. Thus these should be optional
with beta(offset) defaulting to zero and gamma(scale) to 1. Will also reduce
some memory requirements.
Change-Id: I15bf1ec14b814be2acebf1be1a4fba9c4fbd3190
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123237
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/GLES_COMPUTE')
-rw-r--r-- | tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp index d817fc0e67..2dbb0e0fbb 100644 --- a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp +++ b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp @@ -59,8 +59,11 @@ TEST_SUITE(BatchNormalizationLayer) template <typename T> using GCBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture<GCTensor, GCAccessor, GCBatchNormalizationLayer, T>; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::RandomBatchNormalizationLayerDataset(), framework::dataset::make("DataType", { DataType::F32 })), - shape0, shape1, epsilon, dt) +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), + framework::dataset::make("DataType", { DataType::F32 })), + shape0, shape1, epsilon, use_beta, use_gamma, dt) { // Set fixed point position data type allowed int fixed_point_position = (arm_compute::is_data_type_fixed_point(dt)) ? 3 : 0; @@ -75,7 +78,9 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran // Create and Configure function GCBatchNormalizationLayer norm; - norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon); + GCTensor *beta_ptr = use_beta ? &beta : nullptr; + GCTensor *gamma_ptr = use_gamma ? &gamma : nullptr; + norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon); // Validate valid region const ValidRegion valid_region = shape_to_valid_region(shape0); @@ -84,7 +89,9 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran TEST_SUITE(Float) TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), act_infos), framework::dataset::make("DataType", DataType::F16))) { @@ -94,7 +101,9 @@ FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework TEST_SUITE_END() TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), act_infos), framework::dataset::make("DataType", DataType::F32))) { |