diff options
Diffstat (limited to 'tests/validation')
5 files changed, 122 insertions, 72 deletions
diff --git a/tests/validation/CL/BatchNormalizationLayer.cpp b/tests/validation/CL/BatchNormalizationLayer.cpp index 8c143060cb..6190e67dba 100644 --- a/tests/validation/CL/BatchNormalizationLayer.cpp +++ b/tests/validation/CL/BatchNormalizationLayer.cpp @@ -32,6 +32,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Helpers.h" #include "tests/validation/Validation.h" #include "tests/validation/fixtures/BatchNormalizationLayerFixture.h" @@ -61,18 +62,25 @@ TEST_SUITE(BatchNormalizationLayer) template <typename T> using CLBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - combine(framework::dataset::make("UseBeta", { false, true }), - framework::dataset::make("UseGamma", { false, true }))), - framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F16, DataType::F32 })), - shape0, shape1, epsilon, use_gamma, use_beta, dt) +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), + framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F16, DataType::F32 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), + shape0, shape1, epsilon, use_gamma, use_beta, dt, data_layout) { // Set fixed point position data type allowed const int fixed_point_position = (arm_compute::is_data_type_fixed_point(dt)) ? 3 : 0; + TensorShape src_dst_shapes = shape0; + if(data_layout == DataLayout::NHWC) + { + permute(src_dst_shapes, PermutationVector(2U, 0U, 1U)); + } + // Create tensors - CLTensor src = create_tensor<CLTensor>(shape0, dt, 1, fixed_point_position); - CLTensor dst = create_tensor<CLTensor>(shape0, dt, 1, fixed_point_position); + CLTensor src = create_tensor<CLTensor>(src_dst_shapes, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); + CLTensor dst = create_tensor<CLTensor>(src_dst_shapes, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); CLTensor mean = create_tensor<CLTensor>(shape1, dt, 1, fixed_point_position); CLTensor var = create_tensor<CLTensor>(shape1, dt, 1, fixed_point_position); CLTensor beta = create_tensor<CLTensor>(shape1, dt, 1, fixed_point_position); @@ -85,7 +93,7 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datas norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon); // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape0); + const ValidRegion valid_region = shape_to_valid_region(src_dst_shapes); validate(dst.info()->valid_region(), valid_region); } @@ -155,11 +163,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), act_infos), - framework::dataset::make("DataType", DataType::F32))) + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW }))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f32, 0); @@ -167,11 +176,12 @@ FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framewor TEST_SUITE_END() TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f))), - framework::dataset::make("DataType", DataType::F16))) + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW }))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_f16, 0); @@ -185,11 +195,12 @@ using CLBatchNormalizationLayerFixedPointFixture = BatchNormalizationLayerValida TEST_SUITE(QS8) FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - framework::dataset::make("UseBeta", false)), - framework::dataset::make("UseGamma", false)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("DataType", DataType::QS8)), + combine(combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + framework::dataset::make("UseBeta", false)), + framework::dataset::make("UseGamma", false)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("DataType", DataType::QS8)), + framework::dataset::make("DataLayout", DataLayout::NCHW)), framework::dataset::make("FractionalBits", 1, 6))) { // Validate output @@ -199,11 +210,12 @@ TEST_SUITE_END() TEST_SUITE(QS16) FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - framework::dataset::make("UseBeta", false)), - framework::dataset::make("UseGamma", false)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("DataType", DataType::QS16)), + combine(combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + framework::dataset::make("UseBeta", false)), + framework::dataset::make("UseGamma", false)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("DataType", DataType::QS16)), + framework::dataset::make("DataLayout", DataLayout::NCHW)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output diff --git a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp index 2dbb0e0fbb..d22f1e9958 100644 --- a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp +++ b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp @@ -32,6 +32,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Helpers.h" #include "tests/validation/Validation.h" #include "tests/validation/fixtures/BatchNormalizationLayerFixture.h" @@ -59,18 +60,25 @@ TEST_SUITE(BatchNormalizationLayer) template <typename T> using GCBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture<GCTensor, GCAccessor, GCBatchNormalizationLayer, T>; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - combine(framework::dataset::make("UseBeta", { false, true }), - framework::dataset::make("UseGamma", { false, true }))), - framework::dataset::make("DataType", { DataType::F32 })), - shape0, shape1, epsilon, use_beta, use_gamma, dt) +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), + framework::dataset::make("UseGamma", { false, true }))), + framework::dataset::make("DataType", { DataType::F32 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW })), + shape0, shape1, epsilon, use_beta, use_gamma, dt, data_layout) { // Set fixed point position data type allowed int fixed_point_position = (arm_compute::is_data_type_fixed_point(dt)) ? 3 : 0; + TensorShape src_dst_shapes = shape0; + if(data_layout == DataLayout::NHWC) + { + permute(src_dst_shapes, PermutationVector(2U, 0U, 1U)); + } + // Create tensors - GCTensor src = create_tensor<GCTensor>(shape0, dt, 1, fixed_point_position); - GCTensor dst = create_tensor<GCTensor>(shape0, dt, 1, fixed_point_position); + GCTensor src = create_tensor<GCTensor>(src_dst_shapes, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); + GCTensor dst = create_tensor<GCTensor>(src_dst_shapes, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); GCTensor mean = create_tensor<GCTensor>(shape1, dt, 1, fixed_point_position); GCTensor var = create_tensor<GCTensor>(shape1, dt, 1, fixed_point_position); GCTensor beta = create_tensor<GCTensor>(shape1, dt, 1, fixed_point_position); @@ -83,17 +91,18 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datas norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon); // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape0); + const ValidRegion valid_region = shape_to_valid_region(src_dst_shapes); validate(dst.info()->valid_region(), valid_region); } TEST_SUITE(Float) TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), act_infos), - framework::dataset::make("DataType", DataType::F16))) + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW }))) { // Validate output validate(GCAccessor(_target), _reference, tolerance_f16, 0); @@ -101,11 +110,12 @@ FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<half>, framework TEST_SUITE_END() TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), act_infos), - framework::dataset::make("DataType", DataType::F32))) + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW }))) { // Validate output validate(GCAccessor(_target), _reference, tolerance_f, 0); diff --git a/tests/validation/NEON/BatchNormalizationLayer.cpp b/tests/validation/NEON/BatchNormalizationLayer.cpp index 7bf1f2633e..53fd0163ff 100644 --- a/tests/validation/NEON/BatchNormalizationLayer.cpp +++ b/tests/validation/NEON/BatchNormalizationLayer.cpp @@ -32,6 +32,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Macros.h" #include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Helpers.h" #include "tests/validation/Validation.h" #include "tests/validation/fixtures/BatchNormalizationLayerFixture.h" @@ -63,17 +64,24 @@ TEST_SUITE(BatchNormalizationLayer) template <typename T> using NEBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture<Tensor, Accessor, NEBatchNormalizationLayer, T>; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), - framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F32 })), - shape0, shape1, epsilon, use_beta, use_gamma, dt) +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), + framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F32 })), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + shape0, shape1, epsilon, use_beta, use_gamma, dt, data_layout) { // Set fixed point position data type allowed const int fixed_point_position = (arm_compute::is_data_type_fixed_point(dt)) ? 3 : 0; + TensorShape src_dst_shapes = shape0; + if(data_layout == DataLayout::NHWC) + { + permute(src_dst_shapes, PermutationVector(2U, 0U, 1U)); + } + // Create tensors - Tensor src = create_tensor<Tensor>(shape0, dt, 1, fixed_point_position); - Tensor dst = create_tensor<Tensor>(shape0, dt, 1, fixed_point_position); + Tensor src = create_tensor<Tensor>(src_dst_shapes, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); + Tensor dst = create_tensor<Tensor>(src_dst_shapes, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); Tensor mean = create_tensor<Tensor>(shape1, dt, 1, fixed_point_position); Tensor var = create_tensor<Tensor>(shape1, dt, 1, fixed_point_position); Tensor beta = create_tensor<Tensor>(shape1, dt, 1, fixed_point_position); @@ -86,7 +94,7 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(datas norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon); // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape0); + const ValidRegion valid_region = shape_to_valid_region(src_dst_shapes); validate(dst.info()->valid_region(), valid_region); } @@ -154,11 +162,13 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( // *INDENT-ON* TEST_SUITE(Float) -FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), act_infos), - framework::dataset::make("DataType", DataType::F32))) + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) { // Validate output validate(Accessor(_target), _reference, tolerance_f32, 0); @@ -166,18 +176,20 @@ FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<float>, framewor TEST_SUITE_END() #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -TEST_SUITE(Float16) -FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), combine(framework::dataset::make("UseBeta", { false, true }), framework::dataset::make("UseGamma", { false, true }))), framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("DataType", DataType::F16))) + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) { // Validate output validate(Accessor(_target), _reference, tolerance_f16, 0); } TEST_SUITE_END() #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +TEST_SUITE_END() TEST_SUITE(Quantized) template <typename T> @@ -185,11 +197,12 @@ using NEBatchNormalizationLayerFixedPointFixture = BatchNormalizationLayerValida TEST_SUITE(QS8) FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - framework::dataset::make("UseBeta", false)), - framework::dataset::make("UseGamma", false)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("DataType", DataType::QS8)), + combine(combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + framework::dataset::make("UseBeta", false)), + framework::dataset::make("UseGamma", false)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("DataType", DataType::QS8)), + framework::dataset::make("DataLayout", DataLayout::NCHW)), framework::dataset::make("FractionalBits", 1, 6))) { // Validate output @@ -199,11 +212,12 @@ TEST_SUITE_END() TEST_SUITE(QS16) FIXTURE_DATA_TEST_CASE(Random, NEBatchNormalizationLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - framework::dataset::make("UseBeta", false)), - framework::dataset::make("UseGamma", false)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo())), - framework::dataset::make("DataType", DataType::QS16)), + combine(combine(combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), + framework::dataset::make("UseBeta", false)), + framework::dataset::make("UseGamma", false)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo())), + framework::dataset::make("DataType", DataType::QS16)), + framework::dataset::make("DataLayout", DataLayout::NCHW)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output diff --git a/tests/validation/fixtures/BatchNormalizationLayerFixture.h b/tests/validation/fixtures/BatchNormalizationLayerFixture.h index 4a6ac1af7f..7e072e7023 100644 --- a/tests/validation/fixtures/BatchNormalizationLayerFixture.h +++ b/tests/validation/fixtures/BatchNormalizationLayerFixture.h @@ -45,14 +45,20 @@ class BatchNormalizationLayerValidationFixedPointFixture : public framework::Fix { public: template <typename...> - void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt, int fractional_bits) + void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt, DataLayout data_layout, int fractional_bits) { _fractional_bits = fractional_bits; _data_type = dt; _use_beta = use_beta; _use_gamma = use_gamma; - _target = compute_target(shape0, shape1, epsilon, act_info, dt, fractional_bits); - _reference = compute_reference(shape0, shape1, epsilon, act_info, dt, fractional_bits); + + if(data_layout == DataLayout::NHWC) + { + permute(shape0, PermutationVector(2U, 0U, 1U)); + } + + _target = compute_target(shape0, shape1, epsilon, act_info, dt, data_layout, fractional_bits); + _reference = compute_reference(shape0, shape1, epsilon, act_info, dt, data_layout, fractional_bits); } protected: @@ -119,11 +125,11 @@ protected: } } - TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fixed_point_position) + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, DataLayout data_layout, int fixed_point_position) { // Create tensors - TensorType src = create_tensor<TensorType>(shape0, dt, 1, fixed_point_position); - TensorType dst = create_tensor<TensorType>(shape0, dt, 1, fixed_point_position); + TensorType src = create_tensor<TensorType>(shape0, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); + TensorType dst = create_tensor<TensorType>(shape0, dt, 1, fixed_point_position, QuantizationInfo(), data_layout); TensorType mean = create_tensor<TensorType>(shape1, dt, 1, fixed_point_position); TensorType var = create_tensor<TensorType>(shape1, dt, 1, fixed_point_position); TensorType beta = create_tensor<TensorType>(shape1, dt, 1, fixed_point_position); @@ -166,10 +172,10 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fixed_point_position) + SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, DataLayout data_layout, int fixed_point_position) { // Create reference - SimpleTensor<T> ref_src{ shape0, dt, 1, fixed_point_position }; + SimpleTensor<T> ref_src{ shape0, dt, 1, fixed_point_position, QuantizationInfo(), data_layout }; SimpleTensor<T> ref_mean{ shape1, dt, 1, fixed_point_position }; SimpleTensor<T> ref_var{ shape1, dt, 1, fixed_point_position }; SimpleTensor<T> ref_beta{ shape1, dt, 1, fixed_point_position }; @@ -194,9 +200,9 @@ class BatchNormalizationLayerValidationFixture : public BatchNormalizationLayerV { public: template <typename...> - void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt) + void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt, DataLayout data_layout) { - BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, use_beta, use_gamma, act_info, dt, 0); + BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, use_beta, use_gamma, act_info, dt, data_layout, 0); } }; } // namespace validation diff --git a/tests/validation/reference/BatchNormalizationLayer.cpp b/tests/validation/reference/BatchNormalizationLayer.cpp index c8badacc79..ae309d9093 100644 --- a/tests/validation/reference/BatchNormalizationLayer.cpp +++ b/tests/validation/reference/BatchNormalizationLayer.cpp @@ -27,6 +27,7 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" +#include "tests/validation/reference/Permute.h" namespace arm_compute { @@ -41,6 +42,7 @@ template <typename T, typename std::enable_if<std::is_integral<T>::value, int>:: SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &mean, const SimpleTensor<T> &var, const SimpleTensor<T> &beta, const SimpleTensor<T> &gamma, float epsilon, ActivationLayerInfo act_info, int fixed_point_position) { + ARM_COMPUTE_ERROR_ON_MSG(src.data_layout() == DataLayout::NHWC, "Unsupported NHWC format"); ARM_COMPUTE_UNUSED(act_info); SimpleTensor<T> result(src.shape(), src.data_type()); @@ -86,12 +88,14 @@ SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const Simp { ARM_COMPUTE_UNUSED(fixed_point_position); - SimpleTensor<T> result(src.shape(), src.data_type()); + const bool is_nhwc = src.data_layout() == DataLayout::NHWC; + const SimpleTensor<T> perm_src = (is_nhwc) ? permute(src, PermutationVector(1U, 2U, 0U)) : src; + SimpleTensor<T> result(perm_src.shape(), perm_src.data_type()); - const auto cols = static_cast<int>(src.shape()[0]); - const auto rows = static_cast<int>(src.shape()[1]); - const auto depth = static_cast<int>(src.shape()[2]); - const int upper_dims = src.shape().total_size() / (cols * rows * depth); + const auto cols = static_cast<int>(perm_src.shape()[0]); + const auto rows = static_cast<int>(perm_src.shape()[1]); + const auto depth = static_cast<int>(perm_src.shape()[2]); + const int upper_dims = perm_src.shape().total_size() / (cols * rows * depth); for(int r = 0; r < upper_dims; ++r) { @@ -103,7 +107,7 @@ SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const Simp { const int pos = l + k * cols + i * rows * cols + r * cols * rows * depth; const float denominator = sqrt(var[i] + epsilon); - const float numerator = src[pos] - mean[i]; + const float numerator = perm_src[pos] - mean[i]; const float x_bar = numerator / denominator; result[pos] = beta[i] + x_bar * gamma[i]; } @@ -116,6 +120,10 @@ SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const Simp result = activation_layer(result, act_info); } + if(is_nhwc) + { + result = permute(result, PermutationVector(2U, 0U, 1U)); + } return result; } template SimpleTensor<float> batch_normalization_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &mean, const SimpleTensor<float> &var, const SimpleTensor<float> &beta, |