From 5b61fd3fbaf41031232296abde56258d12ba3340 Mon Sep 17 00:00:00 2001 From: Moritz Pflanzer Date: Tue, 12 Sep 2017 15:51:33 +0100 Subject: COMPMID-417: Fix validation Change-Id: I7a745037136bc6e02d177f65fe4f4cd43873b98e Reviewed-on: http://mpd-gerrit.cambridge.arm.com/87406 Tested-by: Kaizen Reviewed-by: Anthony Barbier --- tests/framework/Framework.cpp | 18 ++++---- tests/framework/printers/JSONPrinter.cpp | 66 ++++++++++++++++++++++-------- tests/framework/printers/JSONPrinter.h | 11 ++++- tests/framework/printers/PrettyPrinter.cpp | 5 ++- tests/framework/printers/PrettyPrinter.h | 2 +- tests/framework/printers/Printer.h | 5 ++- tests/validation/CL/SoftmaxLayer.cpp | 12 +++--- tests/validation/NEON/SoftmaxLayer.cpp | 12 +++--- tests/validation/Validation.cpp | 4 +- tests/validation/Validation.h | 16 +++++--- 10 files changed, 101 insertions(+), 50 deletions(-) diff --git a/tests/framework/Framework.cpp b/tests/framework/Framework.cpp index 343b7a8561..31e524338b 100644 --- a/tests/framework/Framework.cpp +++ b/tests/framework/Framework.cpp @@ -164,7 +164,8 @@ void Framework::log_failed_expectation(const TestError &error) { if(_log_level >= error.level() && _printer != nullptr) { - _printer->print_error(error); + constexpr bool expected_error = true; + _printer->print_error(error, expected_error); } if(_current_test_result != nullptr) @@ -232,6 +233,8 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory) _printer->print_errors_header(); } + const bool is_expected_failure = test_factory.status() == TestCaseFactory::Status::EXPECTED_FAILURE; + try { std::unique_ptr test_case = test_factory.make(); @@ -265,7 +268,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory) { if(_log_level >= error.level() && _printer != nullptr) { - _printer->print_error(error); + _printer->print_error(error, is_expected_failure); } result.status = TestResult::Status::FAILED; @@ -282,7 +285,8 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory) { std::stringstream stream; stream << "Error code: " << error.err(); - _printer->print_error(TestError(error.what(), LogLevel::ERRORS, stream.str())); + TestError test_error(error.what(), LogLevel::ERRORS, stream.str()); + _printer->print_error(test_error, is_expected_failure); } result.status = TestResult::Status::FAILED; @@ -297,7 +301,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory) { if(_log_level >= LogLevel::ERRORS && _printer != nullptr) { - _printer->print_error(error); + _printer->print_error(error, is_expected_failure); } result.status = TestResult::Status::CRASHED; @@ -311,7 +315,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory) { if(_log_level >= LogLevel::ERRORS && _printer != nullptr) { - _printer->print_error(TestError("Received unknown exception")); + _printer->print_error(TestError("Received unknown exception"), is_expected_failure); } result.status = TestResult::Status::CRASHED; @@ -326,7 +330,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory) { if(_log_level >= LogLevel::ERRORS && _printer != nullptr) { - _printer->print_error(error); + _printer->print_error(error, is_expected_failure); } result.status = TestResult::Status::CRASHED; @@ -340,7 +344,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory) { if(_log_level >= LogLevel::ERRORS && _printer != nullptr) { - _printer->print_error(TestError("Received unknown exception")); + _printer->print_error(TestError("Received unknown exception"), is_expected_failure); } result.status = TestResult::Status::CRASHED; diff --git a/tests/framework/printers/JSONPrinter.cpp b/tests/framework/printers/JSONPrinter.cpp index ae19cae67c..4f17e6277b 100644 --- a/tests/framework/printers/JSONPrinter.cpp +++ b/tests/framework/printers/JSONPrinter.cpp @@ -46,6 +46,31 @@ void JSONPrinter::print_separator(bool &flag) } } +template +void JSONPrinter::print_strings(T &&first, T &&last) +{ + bool first_entry = true; + std::stringstream log; + + while(first != last) + { + print_separator(first_entry); + + *_stream << R"(")"; + + log.str(*first); + + for(std::string line; !std::getline(log, line).fail();) + { + *_stream << line << "; "; + } + + *_stream << R"(")"; + + ++first; + } +} + void JSONPrinter::print_entry(const std::string &name, const std::string &value) { print_separator(_first_entry); @@ -90,38 +115,43 @@ void JSONPrinter::print_test_footer() void JSONPrinter::print_errors_header() { - print_separator(_first_test_entry); - - _first_error = true; - *_stream << R"("errors" : [)"; + _errors.clear(); + _expected_errors.clear(); + _infos.clear(); } void JSONPrinter::print_errors_footer() { + print_separator(_first_test_entry); + + *_stream << R"("errors" : [)"; + print_strings(_errors.begin(), _errors.end()); + *_stream << "]"; + + *_stream << R"(, "expected_errors" : [)"; + print_strings(_expected_errors.begin(), _expected_errors.end()); + *_stream << "]"; + + *_stream << R"(, "infos" : [)"; + print_strings(_infos.begin(), _infos.end()); *_stream << "]"; } -void JSONPrinter::print_error(const std::exception &error) +void JSONPrinter::print_error(const std::exception &error, bool expected) { - std::stringstream error_log; - error_log.str(error.what()); - - for(std::string line; !std::getline(error_log, line).fail();) + if(expected) { - print_separator(_first_error); - - *_stream << R"(")" << line << R"(")"; + _expected_errors.emplace_back(error.what()); + } + else + { + _errors.emplace_back(error.what()); } } void JSONPrinter::print_info(const std::string &info) { - std::istringstream iss(info); - for(std::string line; !std::getline(iss, line).fail();) - { - print_separator(_first_error); - *_stream << R"(")" << line << R"(")"; - } + _infos.push_back(info); } void JSONPrinter::print_measurements(const Profiler::MeasurementsMap &measurements) diff --git a/tests/framework/printers/JSONPrinter.h b/tests/framework/printers/JSONPrinter.h index 18bd4380b0..a2811ea41b 100644 --- a/tests/framework/printers/JSONPrinter.h +++ b/tests/framework/printers/JSONPrinter.h @@ -26,6 +26,8 @@ #include "Printer.h" +#include + namespace arm_compute { namespace test @@ -47,17 +49,22 @@ public: void print_test_footer() override; void print_errors_header() override; void print_errors_footer() override; - void print_error(const std::exception &error) override; + void print_error(const std::exception &error, bool expected) override; void print_info(const std::string &info) override; void print_measurements(const Profiler::MeasurementsMap &measurements) override; private: void print_separator(bool &flag); + template + void print_strings(T &&first, T &&last); + + std::list _infos{}; + std::list _errors{}; + std::list _expected_errors{}; bool _first_entry{ true }; bool _first_test{ true }; bool _first_test_entry{ true }; - bool _first_error{ true }; }; } // namespace framework } // namespace test diff --git a/tests/framework/printers/PrettyPrinter.cpp b/tests/framework/printers/PrettyPrinter.cpp index 2f7df1837a..5eec72a2fe 100644 --- a/tests/framework/printers/PrettyPrinter.cpp +++ b/tests/framework/printers/PrettyPrinter.cpp @@ -102,9 +102,10 @@ void PrettyPrinter::print_info(const std::string &info) *_stream << begin_color("1") << "INFO: " << info << end_color() << "\n"; } -void PrettyPrinter::print_error(const std::exception &error) +void PrettyPrinter::print_error(const std::exception &error, bool expected) { - *_stream << begin_color("1") << "ERROR: " << error.what() << end_color() << "\n"; + std::string prefix = expected ? "EXPECTED ERROR: " : "ERROR: "; + *_stream << begin_color("1") << prefix << error.what() << end_color() << "\n"; } void PrettyPrinter::print_measurements(const Profiler::MeasurementsMap &measurements) diff --git a/tests/framework/printers/PrettyPrinter.h b/tests/framework/printers/PrettyPrinter.h index 3e2bebd1f7..f72a613868 100644 --- a/tests/framework/printers/PrettyPrinter.h +++ b/tests/framework/printers/PrettyPrinter.h @@ -53,7 +53,7 @@ public: void print_test_footer() override; void print_errors_header() override; void print_errors_footer() override; - void print_error(const std::exception &error) override; + void print_error(const std::exception &error, bool expected) override; void print_info(const std::string &info) override; void print_measurements(const Profiler::MeasurementsMap &measurements) override; diff --git a/tests/framework/printers/Printer.h b/tests/framework/printers/Printer.h index 16a4170d7a..c2a44240ba 100644 --- a/tests/framework/printers/Printer.h +++ b/tests/framework/printers/Printer.h @@ -104,9 +104,10 @@ public: /** Print test error. * - * @param[in] error Description of the error. + * @param[in] error Description of the error. + * @param[in] expected Whether the error was expected or not. */ - virtual void print_error(const std::exception &error) = 0; + virtual void print_error(const std::exception &error, bool expected) = 0; /** Print test log info. * diff --git a/tests/validation/CL/SoftmaxLayer.cpp b/tests/validation/CL/SoftmaxLayer.cpp index 6a22eb1bcc..8c143ecd96 100644 --- a/tests/validation/CL/SoftmaxLayer.cpp +++ b/tests/validation/CL/SoftmaxLayer.cpp @@ -48,7 +48,7 @@ RelativeTolerance tolerance_f16(half_float::half(0.2)); RelativeTolerance tolerance_f32(0.001f); /** Tolerance for fixed point operations */ -constexpr AbsoluteTolerance tolerance_fixed_point(2); +constexpr AbsoluteTolerance tolerance_fixed_point(2); /** CNN data types */ const auto CNNDataTypes = framework::dataset::make("DataType", @@ -145,15 +145,17 @@ TEST_SUITE_END() TEST_SUITE(QS16) // Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14 -FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_fixed_point); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp index 36f1881147..7ac7759c22 100644 --- a/tests/validation/NEON/SoftmaxLayer.cpp +++ b/tests/validation/NEON/SoftmaxLayer.cpp @@ -49,7 +49,7 @@ constexpr AbsoluteTolerance tolerance_f32(0.000001f); constexpr AbsoluteTolerance tolerance_f16(0.0001f); #endif /* ARM_COMPUTE_ENABLE_FP16*/ /** Tolerance for fixed point operations */ -constexpr AbsoluteTolerance tolerance_fixed_point(2); +constexpr AbsoluteTolerance tolerance_fixed_point(2); /** CNN data types */ const auto CNNDataTypes = framework::dataset::make("DataType", @@ -151,15 +151,17 @@ TEST_SUITE_END() TEST_SUITE(QS16) // Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14 -FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output validate(Accessor(_target), _reference, tolerance_fixed_point); } -FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", - DataType::QS16)), +FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataType", + DataType::QS16)), framework::dataset::make("FractionalBits", 1, 14))) { // Validate output diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp index 690c4eac9e..1a082111a9 100644 --- a/tests/validation/Validation.cpp +++ b/tests/validation/Validation.cpp @@ -130,7 +130,7 @@ void check_border_element(const IAccessor &tensor, const Coordinates &id, const double target = get_double_data(ptr + channel_offset, tensor.data_type()); const double reference = get_double_data(static_cast(border_value) + channel_offset, tensor.data_type()); - if(!compare, double>(target, reference)) + if(!compare>(target, reference)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << channel); @@ -192,7 +192,7 @@ void validate(const IAccessor &tensor, const void *reference_value) const double target = get_double_data(ptr + channel_offset, tensor.data_type()); const double reference = get_double_data(reference_value, tensor.data_type()); - if(!compare, double>(target, reference)) + if(!compare>(target, reference)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << channel); diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index e70c970cc1..6bc42a4ed6 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -226,11 +226,11 @@ struct compare_base T _tolerance{}; }; -template +template struct compare; template -struct compare, U> : public compare_base> +struct compare> : public compare_base> { using compare_base>::compare_base; @@ -245,12 +245,16 @@ struct compare, U> : public compare_base(std::abs(this->_target - this->_reference)) <= static_cast(this->_tolerance); + using comparison_type = typename std::conditional::value, int64_t, U>::type; + + const comparison_type abs_difference(std::abs(static_cast(this->_target) - static_cast(this->_reference))); + + return abs_difference <= static_cast(this->_tolerance); } }; template -struct compare, U> : public compare_base> +struct compare> : public compare_base> { using compare_base>::compare_base; @@ -325,7 +329,7 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const V const T &target_value = reinterpret_cast(tensor(id))[c]; const T &reference_value = reinterpret_cast(reference(id))[c]; - if(!compare(target_value, reference_value, tolerance_value)) + if(!compare(target_value, reference_value, tolerance_value)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << c); @@ -359,7 +363,7 @@ void validate(T target, T reference, U tolerance) ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference)); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target)); ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast(tolerance))); - ARM_COMPUTE_EXPECT((compare(target, reference, tolerance)), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((compare(target, reference, tolerance)), framework::LogLevel::ERRORS); } } // namespace validation } // namespace test -- cgit v1.2.1