From 81870c05533cba03373d5e51fed95cd5e74f741d Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 30 Apr 2020 12:02:20 +0100 Subject: IVGCVSW-4743: Fix CpuAcc Hal 1.3 Comparison Failures Broadcast for QASYMM8_SIGNED was not handled. Change-Id: Id5dbb0dce78838319218de94551bba52d697f4a4 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3131 Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- tests/validation/NEON/Comparisons.cpp | 37 ++++++++++++--------------- tests/validation/fixtures/ComparisonFixture.h | 13 +++++++++- 2 files changed, 29 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/tests/validation/NEON/Comparisons.cpp b/tests/validation/NEON/Comparisons.cpp index 38e440e649..f080c834e5 100644 --- a/tests/validation/NEON/Comparisons.cpp +++ b/tests/validation/NEON/Comparisons.cpp @@ -52,8 +52,9 @@ const auto configure_dataset = combine(datasets::SmallShapes(), DataType::F32 })); -const auto run_small_dataset = combine(datasets::ComparisonOperations(), datasets::SmallShapes()); -const auto run_large_dataset = combine(datasets::ComparisonOperations(), datasets::LargeShapes()); +const auto run_small_dataset = combine(datasets::ComparisonOperations(), datasets::SmallShapes()); +const auto run_small_broadcast_dataset = combine(datasets::ComparisonOperations(), datasets::SmallShapesBroadcast()); +const auto run_large_dataset = combine(datasets::ComparisonOperations(), datasets::LargeShapes()); } // namespace @@ -90,23 +91,6 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( // clang-format on // *INDENT-ON* -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, configure_dataset, - shape, data_type) -{ - // Create tensors - Tensor ref_src1 = create_tensor(shape, data_type); - Tensor ref_src2 = create_tensor(shape, data_type); - Tensor dst = create_tensor(shape, DataType::U8); - - // Create and Configure function - NEElementwiseComparison compare; - compare.configure(&ref_src1, &ref_src2, &dst, ComparisonOperation::Equal); - - // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape); - validate(dst.info()->valid_region(), valid_region); -} - template using NEComparisonFixture = ComparisonValidationFixture; @@ -154,6 +138,8 @@ TEST_SUITE_END() // Float template using NEComparisonQuantizedFixture = ComparisonValidationQuantizedFixture; +template +using NEComparisonQuantizedBroadcastFixture = ComparisonQuantizedBroadcastValidationFixture; TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) @@ -169,11 +155,22 @@ FIXTURE_DATA_TEST_CASE(RunSmall, } TEST_SUITE_END() TEST_SUITE(QASYMM8_SIGNED) +FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, + NEComparisonQuantizedBroadcastFixture, + framework::DatasetMode::ALL, + combine(combine(combine(run_small_broadcast_dataset, framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1, -30) })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.3f, 2) }))) +{ + // Validate output + validate(Accessor(_target), _reference); +} + FIXTURE_DATA_TEST_CASE(RunSmall, NEComparisonQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("QuantizationInfo", { QuantizationInfo() })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1, -30) })), framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.3f, 2) }))) { // Validate output diff --git a/tests/validation/fixtures/ComparisonFixture.h b/tests/validation/fixtures/ComparisonFixture.h index b2fe42de26..d1e1a539c7 100644 --- a/tests/validation/fixtures/ComparisonFixture.h +++ b/tests/validation/fixtures/ComparisonFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -146,6 +146,17 @@ public: ComparisonValidationGenericFixture::setup(op, shape, shape, data_type, qinfo0, qinfo1); } }; + +template +class ComparisonQuantizedBroadcastValidationFixture : public ComparisonValidationGenericFixture +{ +public: + template + void setup(ComparisonOperation op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, QuantizationInfo qinfo0, QuantizationInfo qinfo1) + { + ComparisonValidationGenericFixture::setup(op, shape0, shape1, data_type, qinfo0, qinfo1); + } +}; } // namespace validation } // namespace test } // namespace arm_compute -- cgit v1.2.1