diff options
Diffstat (limited to 'framework/Framework.h')
-rw-r--r-- | framework/Framework.h | 30 |
1 files changed, 26 insertions, 4 deletions
diff --git a/framework/Framework.h b/framework/Framework.h index e9beafd6eb..fbc95badd6 100644 --- a/framework/Framework.h +++ b/framework/Framework.h @@ -24,10 +24,12 @@ #ifndef ARM_COMPUTE_TEST_FRAMEWORK #define ARM_COMPUTE_TEST_FRAMEWORK +#include "Profiler.h" #include "TestCase.h" #include "TestCaseFactory.h" #include "TestResult.h" #include "Utils.h" +#include "instruments/Instruments.h" #include <algorithm> #include <chrono> @@ -70,13 +72,20 @@ public: */ static Framework &get(); + /** Supported instrument types for benchmarking. + * + * @return Set of all available instrument types. + */ + std::set<InstrumentType> available_instruments() const; + /** Init the framework. * + * @param[in] instruments Instrument types that will be used for benchmarking. * @param[in] num_iterations Number of iterations per test. * @param[in] name_filter Regular expression to filter tests by name. Only matching tests will be executed. * @param[in] id_filter Regular expression to filter tests by id. Only matching tests will be executed. */ - void init(int num_iterations, const std::string &name_filter, const std::string &id_filter); + void init(const std::vector<InstrumentType> &instruments, int num_iterations, const std::string &name_filter, const std::string &id_filter); /** Add a new test suite. * @@ -181,6 +190,15 @@ public: */ void set_test_result(std::string test_case_name, TestResult result); + /** Factory method to obtain a configured profiler. + * + * The profiler enables all instruments that have been passed to the @ref + * init method. + * + * @return Configured profiler to collect benchmark results. + */ + Profiler get_profiler() const; + /** List of @ref TestId's. * * @return Vector with all test ids. @@ -188,7 +206,7 @@ public: std::vector<Framework::TestId> test_ids() const; private: - Framework() = default; + Framework(); ~Framework() = default; Framework(const Framework &) = delete; @@ -214,8 +232,12 @@ private: int _num_iterations{ 1 }; bool _throw_errors{ false }; - std::regex _test_name_filter{ ".*" }; - std::regex _test_id_filter{ ".*" }; + using create_function = std::unique_ptr<Instrument>(); + std::map<InstrumentType, create_function *> _available_instruments{}; + + InstrumentType _instruments{ InstrumentType::NONE }; + std::regex _test_name_filter{ ".*" }; + std::regex _test_id_filter{ ".*" }; }; template <typename T> |