From aeebe4ad332bc6dd4797243b39ec6000975429c5 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Wed, 9 Jan 2019 14:21:03 +0000 Subject: COMPMID-1652 NEON Cleanup and add missing tests Change-Id: I7b351f18a78ed8a250bf3a91ef320db61984146a Reviewed-on: https://review.mlplatform.org/485 Reviewed-by: Isabella Gottardi Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- .../GLES_COMPUTE/BatchNormalizationLayer.cpp | 28 ++++++++++------------ 1 file changed, 12 insertions(+), 16 deletions(-) (limited to 'tests/validation/GLES_COMPUTE') diff --git a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp index 3a3d1d796d..d08459d20d 100644 --- a/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp +++ b/tests/validation/GLES_COMPUTE/BatchNormalizationLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -52,6 +52,13 @@ const auto act_infos = framework::dataset::make("Activat ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f), }); + +const auto data_GB = combine(framework::dataset::make("UseGamma", { false, true }), + framework::dataset::make("UseBeta", { false, true })); +const auto data_f16 = combine(combine(combine(data_GB, act_infos), framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); +const auto data_f32 = combine(combine(combine(data_GB, act_infos), framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW })); } // namespace TEST_SUITE(GC) @@ -60,9 +67,8 @@ TEST_SUITE(BatchNormalizationLayer) template using GCBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - combine(framework::dataset::make("UseBeta", { false, true }), - framework::dataset::make("UseGamma", { false, true }))), +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallRandomBatchNormalizationLayerDataset(), + data_GB), framework::dataset::make("DataType", { DataType::F32 })), framework::dataset::make("DataLayout", { DataLayout::NCHW })), shape0, shape1, epsilon, use_beta, use_gamma, dt, data_layout) @@ -94,12 +100,7 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi TEST_SUITE(Float) TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - combine(framework::dataset::make("UseBeta", { false, true }), - framework::dataset::make("UseGamma", { false, true }))), - act_infos), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NCHW }))) +FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallRandomBatchNormalizationLayerDataset(), data_f16)) { // Validate output validate(GCAccessor(_target), _reference, tolerance_f16, 0); @@ -107,12 +108,7 @@ FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture, framework TEST_SUITE_END() TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::RandomBatchNormalizationLayerDataset(), - combine(framework::dataset::make("UseBeta", { false, true }), - framework::dataset::make("UseGamma", { false, true }))), - act_infos), - framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NCHW }))) +FIXTURE_DATA_TEST_CASE(Random, GCBatchNormalizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::LargeRandomBatchNormalizationLayerDataset(), data_f32)) { // Validate output validate(GCAccessor(_target), _reference, tolerance_f, 0); -- cgit v1.2.1