From 8d5d78ba48358e5c511d4c625c17d99065763945 Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Tue, 15 Dec 2020 20:25:31 +0000 Subject: COMPMID-3871: Create BatchNormalization SVE/SVE2 1. Decouple data type for NHWC 2. Add NHWC SVE support for BachNormalization Signed-off-by: Sheri Zhang Change-Id: I0383b969b555b429d9acebb4efa17ecba9429ea7 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4755 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michalis Spyrou --- tests/validation/NEON/BatchNormalizationLayer.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/validation/NEON/BatchNormalizationLayer.cpp b/tests/validation/NEON/BatchNormalizationLayer.cpp index 067c5bb198..b24357f8ad 100644 --- a/tests/validation/NEON/BatchNormalizationLayer.cpp +++ b/tests/validation/NEON/BatchNormalizationLayer.cpp @@ -51,8 +51,10 @@ namespace RelativeTolerance rel_tolerance_f32(0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ constexpr AbsoluteTolerance abs_tolerance_f32(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -constexpr AbsoluteTolerance tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +RelativeTolerance rel_tolerance_f16(0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ +constexpr AbsoluteTolerance abs_tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const auto act_infos = framework::dataset::make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), @@ -148,7 +150,7 @@ FIXTURE_DATA_TEST_CASE(RandomSmall, NEBatchNormalizationLayerFixture, fram framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) { // Validate output - validate(Accessor(_target), _reference, tolerance_f16, 0); + validate(Accessor(_target), _reference, abs_tolerance_f16, 0); } FIXTURE_DATA_TEST_CASE(RandomLarge, NEBatchNormalizationLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::LargeRandomBatchNormalizationLayerDataset(), @@ -159,7 +161,7 @@ FIXTURE_DATA_TEST_CASE(RandomLarge, NEBatchNormalizationLayerFixture, fram framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }))) { // Validate output - validate(Accessor(_target), _reference, tolerance_f16, 0); + validate(Accessor(_target), _reference, abs_tolerance_f16, 0); } TEST_SUITE_END() // FP16 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -- cgit v1.2.1