diff options
Diffstat (limited to 'framework/Framework.cpp')
-rw-r--r-- | framework/Framework.cpp | 286 |
1 files changed, 286 insertions, 0 deletions
diff --git a/framework/Framework.cpp b/framework/Framework.cpp new file mode 100644 index 0000000000..b54c0c75b6 --- /dev/null +++ b/framework/Framework.cpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "Framework.h" + +#include "Exceptions.h" +#include "support/ToolchainSupport.h" + +#include <chrono> +#include <iostream> +#include <sstream> +#include <type_traits> + +namespace arm_compute +{ +namespace test +{ +namespace framework +{ +std::tuple<int, int, int> Framework::count_test_results() const +{ + int passed = 0; + int failed = 0; + int crashed = 0; + + for(const auto &test : _test_results) + { + switch(test.second.status) + { + case TestResult::Status::SUCCESS: + ++passed; + break; + case TestResult::Status::FAILED: + ++failed; + break; + case TestResult::Status::CRASHED: + ++crashed; + break; + default: + // Do nothing + break; + } + } + + return std::make_tuple(passed, failed, crashed); +} + +Framework &Framework::get() +{ + static Framework instance; + return instance; +} + +void Framework::init(int num_iterations, const std::string &name_filter, const std::string &id_filter) +{ + _test_name_filter = std::regex{ name_filter }; + _test_id_filter = std::regex{ id_filter }; + _num_iterations = num_iterations; +} + +std::string Framework::current_suite_name() const +{ + return join(_test_suite_name.cbegin(), _test_suite_name.cend(), "/"); +} + +void Framework::push_suite(std::string name) +{ + _test_suite_name.emplace_back(std::move(name)); +} + +void Framework::pop_suite() +{ + _test_suite_name.pop_back(); +} + +void Framework::log_test_start(const std::string &test_name) +{ + static_cast<void>(test_name); +} + +void Framework::log_test_skipped(const std::string &test_name) +{ + static_cast<void>(test_name); +} + +void Framework::log_test_end(const std::string &test_name) +{ + static_cast<void>(test_name); +} + +void Framework::log_failed_expectation(const std::string &msg) +{ + std::cerr << "ERROR: " << msg << "\n"; +} + +int Framework::num_iterations() const +{ + return _num_iterations; +} + +void Framework::set_num_iterations(int num_iterations) +{ + _num_iterations = num_iterations; +} + +void Framework::set_throw_errors(bool throw_errors) +{ + _throw_errors = throw_errors; +} + +bool Framework::throw_errors() const +{ + return _throw_errors; +} + +bool Framework::is_enabled(const TestId &id) const +{ + return (std::regex_search(support::cpp11::to_string(id.first), _test_id_filter) && std::regex_search(id.second, _test_name_filter)); +} + +void Framework::run_test(TestCaseFactory &test_factory) +{ + const std::string test_case_name = test_factory.name(); + + log_test_start(test_case_name); + + TestResult result; + + try + { + std::unique_ptr<TestCase> test_case = test_factory.make(); + + try + { + test_case->do_setup(); + + for(int i = 0; i < _num_iterations; ++i) + { + test_case->do_run(); + } + + test_case->do_teardown(); + + result.status = TestResult::Status::SUCCESS; + } + catch(const TestError &error) + { + std::cerr << "FATAL ERROR: " << error.what() << "\n"; + result.status = TestResult::Status::FAILED; + + if(_throw_errors) + { + throw; + } + } + catch(const std::exception &error) + { + std::cerr << "FATAL ERROR: Received unhandled error: '" << error.what() << "'\n"; + result.status = TestResult::Status::CRASHED; + + if(_throw_errors) + { + throw; + } + } + catch(...) + { + std::cerr << "FATAL ERROR: Received unhandled exception\n"; + result.status = TestResult::Status::CRASHED; + + if(_throw_errors) + { + throw; + } + } + } + catch(const std::exception &error) + { + std::cerr << "FATAL ERROR: Received unhandled error during fixture creation: '" << error.what() << "'\n"; + + if(_throw_errors) + { + throw; + } + } + catch(...) + { + std::cerr << "FATAL ERROR: Received unhandled exception during fixture creation\n"; + + if(_throw_errors) + { + throw; + } + } + + set_test_result(test_case_name, result); + log_test_end(test_case_name); +} + +bool Framework::run() +{ + // Clear old test results + _test_results.clear(); + _runtime = std::chrono::seconds{ 0 }; + + const auto start = std::chrono::high_resolution_clock::now(); + + int id = 0; + + for(auto &test_factory : _test_factories) + { + const std::string test_case_name = test_factory->name(); + + if(!is_enabled(TestId(id, test_case_name))) + { + log_test_skipped(test_case_name); + } + else + { + run_test(*test_factory); + } + + ++id; + } + + const auto end = std::chrono::high_resolution_clock::now(); + + _runtime = std::chrono::duration_cast<std::chrono::seconds>(end - start); + + int passed = 0; + int failed = 0; + int crashed = 0; + + std::tie(passed, failed, crashed) = count_test_results(); + + std::cout << "Executed " << _test_results.size() << " test(s) (" << passed << " passed, " << failed << " failed, " << crashed << " crashed) in " << _runtime.count() << " second(s)\n"; + + return (static_cast<unsigned int>(passed) == _test_results.size()); +} + +void Framework::set_test_result(std::string test_case_name, TestResult result) +{ + _test_results.emplace(std::move(test_case_name), result); +} + +std::vector<Framework::TestId> Framework::test_ids() const +{ + std::vector<TestId> ids; + + int id = 0; + + for(const auto &factory : _test_factories) + { + if(is_enabled(TestId(id, factory->name()))) + { + ids.emplace_back(id, factory->name()); + } + + ++id; + } + + return ids; +} +} // namespace framework +} // namespace test +} // namespace arm_compute |