From 1167487ea8e54a76d0a3625e0aa84e2ad9ffd317 Mon Sep 17 00:00:00 2001
From: Giorgio Arena
Date: Wed, 7 Feb 2018 15:38:12 +0000
Subject: COMPMID-897 Merge batch normalization with bounded relu

Change-Id: I9a607fe620f795cdea1a99fdd3f5f8c2fc76f980
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/119234
Tested-by: Jenkins
Reviewed-by: Gian Marco Iodice
Reviewed-by: Georgios Pinitas
---
 tests/validation/CL/BatchNormalizationLayer.cpp    | 49 ++++++++++++++++++----
 .../GLES_COMPUTE/BatchNormalizationLayer.cpp       | 14 +++++--
 tests/validation/NEON/BatchNormalizationLayer.cpp  | 45 ++++++++++++++++----
 .../fixtures/BatchNormalizationLayerFixture.h      | 20 ++++-----
 .../reference/BatchNormalizationLayer.cpp          | 24 +++++++----
 .../validation/reference/BatchNormalizationLayer.h |  8 ++--
 6 files changed, 119 insertions(+), 41 deletions(-)

(limited to 'tests/validation')
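Before the per-backend diffs, it helps to pin down what the fused activations in the new "ActivationInfo" datasets actually compute. The sketch below is not part of the patch; it restates ActivationLayerInfo's documented clamping semantics in plain C++:

// Clamping semantics of the activations fused into batch normalization in this patch.
#include <algorithm>

float relu(float x)                              { return std::max(0.f, x); }              // RELU
float bounded_relu(float x, float a)             { return std::min(a, std::max(0.f, x)); } // BOUNDED_RELU(a), e.g. relu6 for a = 6.f
float lu_bounded_relu(float x, float a, float b) { return std::min(a, std::max(b, x)); }   // LU_BOUNDED_RELU(a, b): clamp to [b, a]

Since LU_BOUNDED_RELU(a, b) clamps to the interval [b, a], a configuration with a < b is nonsensical; the new Validate cases below rely on exactly that.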
diff --git a/tests/validation/CL/BatchNormalizationLayer.cpp b/tests/validation/CL/BatchNormalizationLayer.cpp
index 30dd70a66a..ef535153f2 100644
--- a/tests/validation/CL/BatchNormalizationLayer.cpp
+++ b/tests/validation/CL/BatchNormalizationLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -47,6 +47,12 @@ constexpr AbsoluteTolerance<float> tolerance_f32(0.00001f); /**< Tolerance value
 constexpr AbsoluteTolerance<float> tolerance_f16(0.01f);    /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
 constexpr AbsoluteTolerance<float> tolerance_qs8(3.0f);     /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
 constexpr AbsoluteTolerance<float> tolerance_qs16(6.0f);    /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS16 */
+const auto act_infos = framework::dataset::make("ActivationInfo",
+{
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
+});
 } // namespace

 TEST_SUITE(CL)
@@ -80,13 +86,16 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran
 // *INDENT-OFF*
 // clang-format off
-DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
    framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                            TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32),    // Window shrink
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Mismatching data types
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Mismatching data types
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Invalid mean/var/beta/gamma shape
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point position
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Fused activation with fixed point not supported
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Unsupported fused activation
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Fused activation's a < b
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
                                          }),
@@ -96,6 +105,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
                                            TensorInfo(),
                                          })),
@@ -106,16 +118,31 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
                                            TensorInfo(TensorShape(5U), 1, DataType::F32),
                                            TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
                                            TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
+                                           TensorInfo(TensorShape(2U), 1, DataType::F32),
+                                           TensorInfo(TensorShape(2U), 1, DataType::F32),
                                            TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
+                                           TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
+                                         })),
+   framework::dataset::make("ActivationLayerInfo",{ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f, 2.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f, 2.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 2.f, 6.f),
+                                                    ActivationLayerInfo(),
+                                                    ActivationLayerInfo(),
                                          })),
-   framework::dataset::make("Expected", { true, false, false, false, false, false, true, true})),
-   input_info, output_info, mvbg_info, expected)
+   framework::dataset::make("Expected", { true, false, false, false, false, false, false, false, false, true, true})),
+   input_info, output_info, mvbg_info, act_info, expected)
 {
    const auto &mean_info  = mvbg_info;
    const auto &var_info   = mvbg_info;
    const auto &beta_info  = mvbg_info;
    const auto &gamma_info = mvbg_info;
-   bool has_error = bool(CLBatchNormalizationLayer::validate(&input_info.clone()->set_is_resizable(false), (output_info.total_size() == 0) ? nullptr : &output_info.clone()->set_is_resizable(false), &mean_info.clone()->set_is_resizable(false), &var_info.clone()->set_is_resizable(false), &beta_info.clone()->set_is_resizable(false), &gamma_info.clone()->set_is_resizable(false), 1.f));
+   bool has_error = bool(CLBatchNormalizationLayer::validate(&input_info.clone()->set_is_resizable(false), (output_info.total_size() == 0) ? nullptr : &output_info.clone()->set_is_resizable(false), &mean_info.clone()->set_is_resizable(false), &var_info.clone()->set_is_resizable(false), &beta_info.clone()->set_is_resizable(false), &gamma_info.clone()->set_is_resizable(false), 1.f, act_info));
    ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
@@ -123,7 +150,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                   act_infos),
                                                                                                                    framework::dataset::make("DataType", DataType::F32)))
 {
    // Validate output
@@ -132,7 +160,8 @@ FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framewor
 TEST_SUITE_END()

 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                  framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f))),
                                                                                                                   framework::dataset::make("DataType", DataType::F16)))
 {
    // Validate output
@@ -146,7 +175,8 @@ TEST_SUITE_END()
 template <typename T>
 using CLBatchNormalizationLayerFixedPointFixture = BatchNormalizationLayerValidationFixedPointFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>;

 TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                               framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
                                                                                                                                framework::dataset::make("DataType", DataType::QS8)),
                                                                                                                                framework::dataset::make("FractionalBits", 1, 6)))
 {
@@ -156,7 +186,8 @@ FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int8_t
 TEST_SUITE_END()

 TEST_SUITE(QS16)
-FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                                framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
                                                                                                                                 framework::dataset::make("DataType", DataType::QS16)),
                                                                                                                                 framework::dataset::make("FractionalBits", 1, 14)))
 {
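For orientation, this is how the new trailing act_info argument reaches the public entry point. A minimal sketch only: the shapes and types are assumed to mirror the Validate dataset above, and the true/false outcomes follow the test's Expected convention.

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/functions/CLBatchNormalizationLayer.h"

using namespace arm_compute;

void validate_fused_activation_examples()
{
    // A 2-channel F32 input and length-2 mean/var/beta/gamma vectors, as in the dataset above.
    const TensorInfo input(TensorShape(32U, 13U, 2U), 1, DataType::F32);
    const TensorInfo mvbg(TensorShape(2U), 1, DataType::F32);

    // Fused bounded relu on a float tensor: accepted (Expected: true).
    const bool ok = bool(CLBatchNormalizationLayer::validate(&input, nullptr, &mvbg, &mvbg, &mvbg, &mvbg, 1.f,
                                                             ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f)));

    // LU_BOUNDED_RELU with a < b describes an inverted clamping range: rejected (Expected: false).
    const bool bad = bool(CLBatchNormalizationLayer::validate(&input, nullptr, &mvbg, &mvbg, &mvbg, &mvbg, 1.f,
                                                              ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 2.f, 6.f)));
    (void)ok;
    (void)bad;
}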
diff --git a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp
index a82149bdcc..d817fc0e67 100644
--- a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp
+++ b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -45,6 +45,12 @@ namespace
 {
 constexpr AbsoluteTolerance<float> tolerance_f(0.00001f);   /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
 constexpr AbsoluteTolerance<float> tolerance_f16(0.01f);    /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+const auto act_infos = framework::dataset::make("ActivationInfo",
+{
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
+});
 } // namespace

 TEST_SUITE(GC)
@@ -78,7 +84,8 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                  act_infos),
                                                                                                                   framework::dataset::make("DataType", DataType::F16)))
 {
    // Validate output
@@ -87,7 +94,8 @@ FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework
 TEST_SUITE_END()

 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                   act_infos),
                                                                                                                    framework::dataset::make("DataType", DataType::F32)))
 {
    // Validate output
diff --git a/tests/validation/NEON/BatchNormalizationLayer.cpp b/tests/validation/NEON/BatchNormalizationLayer.cpp
index dfa32bbb07..3501c359db 100644
--- a/tests/validation/NEON/BatchNormalizationLayer.cpp
+++ b/tests/validation/NEON/BatchNormalizationLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -49,6 +49,12 @@ constexpr AbsoluteTolerance<float> tolerance_f16(0.01f); /**< Tolerance value fo
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 constexpr AbsoluteTolerance<float> tolerance_qs8(3.0f);  /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
 constexpr AbsoluteTolerance<float> tolerance_qs16(6.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS16 */
+const auto act_infos = framework::dataset::make("ActivationInfo",
+{
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
+});
 } // namespace

 TEST_SUITE(NEON)
@@ -82,13 +88,15 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran
 // *INDENT-OFF*
 // clang-format off
-DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
    framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                            TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32),    // Window shrink
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Mismatching data types
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Mismatching data types
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Invalid mean/var/beta/gamma shape
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point position
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Fused activation with fixed point not supported
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),    // Fused activation's a < b
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
                                          }),
@@ -98,6 +106,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+                                           TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
                                            TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
                                            TensorInfo(),
                                          })),
@@ -108,10 +118,23 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
                                            TensorInfo(TensorShape(5U), 1, DataType::F32),
                                            TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
                                            TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
+                                           TensorInfo(TensorShape(2U), 1, DataType::F32),
+                                           TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
                                            TensorInfo(TensorShape(2U), 1, DataType::QS8, 2),
                                          })),
-   framework::dataset::make("Expected", { true, false, false, false, false, false, true, true})),
-   input_info, output_info, mvbg_info, expected)
+   framework::dataset::make("ActivationLayerInfo",{ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f, 2.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f, 2.f),
+                                                    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 2.f, 6.f),
+                                                    ActivationLayerInfo(),
+                                                    ActivationLayerInfo(),
+                                                  })),
+   framework::dataset::make("Expected", { true, false, false, false, false, false, false, false, true, true})),
+   input_info, output_info, mvbg_info, act_info, expected)
 {
    const auto &mean_info  = mvbg_info;
    const auto &var_info   = mvbg_info;
    const auto &beta_info  = mvbg_info;
    const auto &gamma_info = mvbg_info;
@@ -120,14 +143,15 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
    bool has_error = bool(NEBatchNormalizationLayer::validate(
                              &input_info.clone()->set_is_resizable(false), output_info.total_size() ? &output_info.clone()->set_is_resizable(false) : nullptr,
                              &mean_info.clone()->set_is_resizable(false), &var_info.clone()->set_is_resizable(false),
-                             &beta_info.clone()->set_is_resizable(false), &gamma_info.clone()->set_is_resizable(false), 1.f));
+                             &beta_info.clone()->set_is_resizable(false), &gamma_info.clone()->set_is_resizable(false), 1.f, act_info));
    ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
 // *INDENT-ON*

 TEST_SUITE(Float)
-FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                   act_infos),
                                                                                                                    framework::dataset::make("DataType", DataType::F32)))
 {
    // Validate output
@@ -137,7 +161,8 @@ TEST_SUITE_END()

 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(Float16)
-FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                  act_infos),
                                                                                                                   framework::dataset::make("DataType", DataType::F16)))
 {
    // Validate output
@@ -151,7 +176,8 @@ TEST_SUITE_END()
 template <typename T>
 using NEBatchNormalizationLayerFixedPointFixture = BatchNormalizationLayerValidationFixedPointFixture<Tensor, Accessor, NEBatchNormalizationLayer, T>;

 TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                               framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
                                                                                                                                framework::dataset::make("DataType", DataType::QS8)),
                                                                                                                                framework::dataset::make("FractionalBits", 1, 6)))
 {
@@ -161,7 +187,8 @@ FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixedPointFixture<int8_t
 TEST_SUITE_END()

 TEST_SUITE(QS16)
-FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(),
+                                                                                                                                framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
                                                                                                                                 framework::dataset::make("DataType", DataType::QS16)),
                                                                                                                                 framework::dataset::make("FractionalBits", 1, 14)))
 {
diff --git a/tests/validation/fixtures/BatchNormalizationLayerFixture.h b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
index 298c9ca411..e02c619249 100644
--- a/tests/validation/fixtures/BatchNormalizationLayerFixture.h
+++ b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -45,12 +45,12 @@ class BatchNormalizationLayerValidationFixedPointFixture : public framework::Fix
 {
 public:
    template <typename...>
-   void setup(TensorShape shape0, TensorShape shape1, float epsilon, DataType dt, int fractional_bits)
+   void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fractional_bits)
    {
        _fractional_bits = fractional_bits;
        _data_type       = dt;
-       _target    = compute_target(shape0, shape1, epsilon, dt, fractional_bits);
-       _reference = compute_reference(shape0, shape1, epsilon, dt, fractional_bits);
+       _target    = compute_target(shape0, shape1, epsilon, act_info, dt, fractional_bits);
+       _reference = compute_reference(shape0, shape1, epsilon, act_info, dt, fractional_bits);
    }

 protected:
@@ -85,7 +85,7 @@ protected:
        }
    }

-   TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, float epsilon, DataType dt, int fixed_point_position)
+   TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fixed_point_position)
    {
        // Create tensors
        TensorType src = create_tensor<TensorType>(shape0, dt, 1, fixed_point_position);
@@ -97,7 +97,7 @@ protected:

        // Create and configure function
        FunctionType norm;
-       norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon);
+       norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon, act_info);

        ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
        ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -130,7 +130,7 @@ protected:
        return dst;
    }

-   SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, float epsilon, DataType dt, int fixed_point_position)
+   SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fixed_point_position)
    {
        // Create reference
        SimpleTensor<T> ref_src{ shape0, dt, 1, fixed_point_position };
@@ -142,7 +142,7 @@ protected:
        // Fill reference
        fill(ref_src, ref_mean, ref_var, ref_beta, ref_gamma);

-       return reference::batch_normalization_layer(ref_src, ref_mean, ref_var, ref_beta, ref_gamma, epsilon, fixed_point_position);
+       return reference::batch_normalization_layer(ref_src, ref_mean, ref_var, ref_beta, ref_gamma, epsilon, act_info, fixed_point_position);
    }

    TensorType _target{};
@@ -156,9 +156,9 @@ class BatchNormalizationLayerValidationFixture : public BatchNormalizationLayerV
 {
 public:
    template <typename...>
-   void setup(TensorShape shape0, TensorShape shape1, float epsilon, DataType dt)
+   void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt)
    {
-       BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, dt, 0);
+       BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, act_info, dt, 0);
    }
 };
 } // namespace validation
diff --git a/tests/validation/reference/BatchNormalizationLayer.cpp b/tests/validation/reference/BatchNormalizationLayer.cpp
index e4446d1694..a9d9f0320d 100644
--- a/tests/validation/reference/BatchNormalizationLayer.cpp
+++ b/tests/validation/reference/BatchNormalizationLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -23,6 +23,8 @@
  */
 #include "BatchNormalizationLayer.h"

+#include "ActivationLayer.h"
+
 #include "tests/validation/FixedPoint.h"
 #include "tests/validation/Helpers.h"

@@ -37,8 +39,9 @@ namespace reference
 // Batch Normalization Layer for fixed point type
 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type *>
 SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &mean, const SimpleTensor<T> &var, const SimpleTensor<T> &beta, const SimpleTensor<T> &gamma, float epsilon,
-                                          int fixed_point_position)
+                                          ActivationLayerInfo act_info, int fixed_point_position)
 {
+   ARM_COMPUTE_UNUSED(act_info);
    SimpleTensor<T> result(src.shape(), src.data_type());

    const auto cols = static_cast<int>(src.shape()[0]);
@@ -79,7 +82,7 @@ SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const Simp
 // Batch Normalization Layer for floating point type
 template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type *>
 SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &mean, const SimpleTensor<T> &var, const SimpleTensor<T> &beta, const SimpleTensor<T> &gamma, float epsilon,
-                                          int fixed_point_position)
+                                          ActivationLayerInfo act_info, int fixed_point_position)
 {
    ARM_COMPUTE_UNUSED(fixed_point_position);

@@ -103,21 +106,28 @@ SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const Simp
                    const float numerator   = src[pos] - mean[i];
                    const float x_bar       = numerator / denominator;
                    result[pos]             = beta[i] + x_bar * gamma[i];
+                   ;
                }
            }
        }
    }
+
+   if(act_info.enabled())
+   {
+       result = activation_layer<T>(result, act_info);
+   }
+
    return result;
 }

 template SimpleTensor<float> batch_normalization_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &mean, const SimpleTensor<float> &var, const SimpleTensor<float> &beta,
-                                                       const SimpleTensor<float> &gamma, float epsilon, int fixed_point_position);
+                                                       const SimpleTensor<float> &gamma, float epsilon, ActivationLayerInfo act_info, int fixed_point_position);
 template SimpleTensor<int8_t> batch_normalization_layer(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &mean, const SimpleTensor<int8_t> &var, const SimpleTensor<int8_t> &beta,
-                                                        const SimpleTensor<int8_t> &gamma, float epsilon, int fixed_point_position);
+                                                        const SimpleTensor<int8_t> &gamma, float epsilon, ActivationLayerInfo act_info, int fixed_point_position);
 template SimpleTensor<int16_t> batch_normalization_layer(const SimpleTensor<int16_t> &src, const SimpleTensor<int16_t> &mean, const SimpleTensor<int16_t> &var, const SimpleTensor<int16_t> &beta,
-                                                         const SimpleTensor<int16_t> &gamma, float epsilon, int fixed_point_position);
+                                                         const SimpleTensor<int16_t> &gamma, float epsilon, ActivationLayerInfo act_info, int fixed_point_position);
 template SimpleTensor<half> batch_normalization_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &mean, const SimpleTensor<half> &var, const SimpleTensor<half> &beta,
-                                                      const SimpleTensor<half> &gamma, float epsilon, int fixed_point_position);
+                                                      const SimpleTensor<half> &gamma, float epsilon, ActivationLayerInfo act_info, int fixed_point_position);
 } // namespace reference
 } // namespace validation
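To see the whole reference computation in one place, here is a self-contained restatement of the float path above: per-channel normalization followed by the optional fused activation. It uses plain std::vector and a hypothetical batch_norm_ref helper instead of SimpleTensor, assumes channel-contiguous layout, and specializes the activation to BOUNDED_RELU(a) for brevity.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> batch_norm_ref(const std::vector<float> &src, const std::vector<float> &mean,
                                  const std::vector<float> &var, const std::vector<float> &beta,
                                  const std::vector<float> &gamma, float epsilon, size_t channels,
                                  bool fuse_bounded_relu, float a)
{
    std::vector<float> result(src.size());
    const size_t plane = src.size() / channels; // elements per channel plane

    for(size_t c = 0; c < channels; ++c)
    {
        const float denominator = std::sqrt(var[c] + epsilon);
        for(size_t p = 0; p < plane; ++p)
        {
            const size_t pos   = c * plane + p;
            const float  x_bar = (src[pos] - mean[c]) / denominator;
            result[pos]        = beta[c] + x_bar * gamma[c];
        }
    }

    // Fused activation, applied to the whole tensor exactly as the reference does
    // via activation_layer(result, act_info); here hard-coded to BOUNDED_RELU(a).
    if(fuse_bounded_relu)
    {
        for(float &v : result)
        {
            v = std::min(a, std::max(0.f, v));
        }
    }
    return result;
}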
diff --git a/tests/validation/reference/BatchNormalizationLayer.h b/tests/validation/reference/BatchNormalizationLayer.h
index 1a554adf7e..329909dab4 100644
--- a/tests/validation/reference/BatchNormalizationLayer.h
+++ b/tests/validation/reference/BatchNormalizationLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -37,11 +37,13 @@ namespace reference
 {
 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type * = nullptr>
 SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &mean, const SimpleTensor<T> &var, const SimpleTensor<T> &beta, const SimpleTensor<T> &gamma, float epsilon,
-                                          int fixed_point_position);
+                                          ActivationLayerInfo act_info,
+                                          int fixed_point_position);

 template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type * = nullptr>
 SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &mean, const SimpleTensor<T> &var, const SimpleTensor<T> &beta, const SimpleTensor<T> &gamma, float epsilon,
-                                          int fixed_point_position);
+                                          ActivationLayerInfo act_info,
+                                          int fixed_point_position);
 } // namespace reference
 } // namespace validation
 } // namespace test
--
cgit v1.2.1
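Finally, what the merge buys at the call site. A sketch, not taken from the patch, contrasting the previous two-layer arrangement with the fused configure call this change enables; the unfused baseline uses CLActivationLayer, and tensor allocation is assumed to happen elsewhere.

#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
#include "arm_compute/runtime/CL/functions/CLBatchNormalizationLayer.h"

using namespace arm_compute;

void configure_batch_norm_relu6(CLTensor &src, CLTensor &tmp, CLTensor &dst,
                                CLTensor &mean, CLTensor &var, CLTensor &beta, CLTensor &gamma, float epsilon)
{
    const ActivationLayerInfo relu6(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);

    // Before this patch: batch normalization followed by a standalone activation
    // layer, i.e. two kernel launches and an intermediate tensor.
    CLBatchNormalizationLayer bn_unfused;
    CLActivationLayer         act;
    bn_unfused.configure(&src, &tmp, &mean, &var, &beta, &gamma, epsilon);
    act.configure(&tmp, &dst, relu6);

    // After this patch: the bounded relu rides along with batch normalization in
    // a single function, which is what the updated fixtures and datasets exercise.
    CLBatchNormalizationLayer bn_fused;
    bn_fused.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon, relu6);
}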