aboutsummaryrefslogtreecommitdiff
path: root/framework/Framework.h
diff options
context:
space:
mode:
Diffstat (limited to 'framework/Framework.h')
-rw-r--r--framework/Framework.h30
1 files changed, 26 insertions, 4 deletions
diff --git a/framework/Framework.h b/framework/Framework.h
index e9beafd6eb..fbc95badd6 100644
--- a/framework/Framework.h
+++ b/framework/Framework.h
@@ -24,10 +24,12 @@
#ifndef ARM_COMPUTE_TEST_FRAMEWORK
#define ARM_COMPUTE_TEST_FRAMEWORK
+#include "Profiler.h"
#include "TestCase.h"
#include "TestCaseFactory.h"
#include "TestResult.h"
#include "Utils.h"
+#include "instruments/Instruments.h"
#include <algorithm>
#include <chrono>
@@ -70,13 +72,20 @@ public:
*/
static Framework &get();
+ /** Supported instrument types for benchmarking.
+ *
+ * @return Set of all available instrument types.
+ */
+ std::set<InstrumentType> available_instruments() const;
+
/** Init the framework.
*
+ * @param[in] instruments Instrument types that will be used for benchmarking.
* @param[in] num_iterations Number of iterations per test.
* @param[in] name_filter Regular expression to filter tests by name. Only matching tests will be executed.
* @param[in] id_filter Regular expression to filter tests by id. Only matching tests will be executed.
*/
- void init(int num_iterations, const std::string &name_filter, const std::string &id_filter);
+ void init(const std::vector<InstrumentType> &instruments, int num_iterations, const std::string &name_filter, const std::string &id_filter);
/** Add a new test suite.
*
@@ -181,6 +190,15 @@ public:
*/
void set_test_result(std::string test_case_name, TestResult result);
+ /** Factory method to obtain a configured profiler.
+ *
+ * The profiler enables all instruments that have been passed to the @ref
+ * init method.
+ *
+ * @return Configured profiler to collect benchmark results.
+ */
+ Profiler get_profiler() const;
+
/** List of @ref TestId's.
*
* @return Vector with all test ids.
@@ -188,7 +206,7 @@ public:
std::vector<Framework::TestId> test_ids() const;
private:
- Framework() = default;
+ Framework();
~Framework() = default;
Framework(const Framework &) = delete;
@@ -214,8 +232,12 @@ private:
int _num_iterations{ 1 };
bool _throw_errors{ false };
- std::regex _test_name_filter{ ".*" };
- std::regex _test_id_filter{ ".*" };
+ using create_function = std::unique_ptr<Instrument>();
+ std::map<InstrumentType, create_function *> _available_instruments{};
+
+ InstrumentType _instruments{ InstrumentType::NONE };
+ std::regex _test_name_filter{ ".*" };
+ std::regex _test_id_filter{ ".*" };
};
template <typename T>