diff options
Diffstat (limited to 'tests/validation/CL/BatchNormalizationLayer.cpp')
-rw-r--r-- | tests/validation/CL/BatchNormalizationLayer.cpp | 41 |
1 files changed, 28 insertions, 13 deletions
diff --git a/tests/validation/CL/BatchNormalizationLayer.cpp b/tests/validation/CL/BatchNormalizationLayer.cpp index ef535153f2..8c143060cb 100644 --- a/tests/validation/CL/BatchNormalizationLayer.cpp +++ b/tests/validation/CL/BatchNormalizationLayer.cpp @@ -61,8 +61,11 @@ TEST_SUITE(BatchNormalizationLayer) template <typename T> using CLBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::RandomBatchNormalizationLayerDataset(), framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F16, DataType::F32 })), - shape0, shape1, epsilon, dt) +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), + framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F16, DataType::F32 })), + shape0, shape1, epsilon, use_gamma, use_beta, dt) { // Set fixed point position data type allowed const int fixed_point_position = (arm_compute::is_data_type_fixed_point(dt)) ? 3 : 0; @@ -77,7 +80,9 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran // Create and Configure function CLBatchNormalizationLayer norm; - norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon); + CLTensor *beta_ptr = use_beta ? &beta : nullptr; + CLTensor *gamma_ptr = use_gamma ? &gamma : nullptr; + norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon); // Validate valid region const ValidRegion valid_region = shape_to_valid_region(shape0); @@ -150,7 +155,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), act_infos), framework::dataset::make("DataType", DataType::F32))) { @@ -160,7 +167,9 @@ FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framewor TEST_SUITE_END() TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f))), framework::dataset::make("DataType", DataType::F16))) { @@ -175,10 +184,13 @@ template <typename T> using CLBatchNormalizationLayerFixedPointFixture = BatchNormalizationLayerValidationFixedPointFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>; TEST_SUITE(QS8) -FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("DataType", DataType::QS8)), - framework::dataset::make("FractionalBits", 1, 6))) +FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + framework::dataset::make("UseBeta", false)), + framework::dataset::make("UseGamma", false)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("DataType", DataType::QS8)), + framework::dataset::make("FractionalBits", 1, 6))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_qs8, 0); @@ -186,10 +198,13 @@ FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int8_t TEST_SUITE_END() TEST_SUITE(QS16) -FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("DataType", DataType::QS16)), - framework::dataset::make("FractionalBits", 1, 14))) +FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + framework::dataset::make("UseBeta", false)), + framework::dataset::make("UseGamma", false)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("DataType", DataType::QS16)), + framework::dataset::make("FractionalBits", 1, 14))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_qs16, 0); |