diff options
Diffstat (limited to 'src/runtime/CL/mlgo/MLGOHeuristics.cpp')
-rw-r--r-- | src/runtime/CL/mlgo/MLGOHeuristics.cpp | 99 |
1 files changed, 51 insertions, 48 deletions
diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.cpp b/src/runtime/CL/mlgo/MLGOHeuristics.cpp index 80f3bb85e9..aed46cd80f 100644 --- a/src/runtime/CL/mlgo/MLGOHeuristics.cpp +++ b/src/runtime/CL/mlgo/MLGOHeuristics.cpp @@ -24,6 +24,7 @@ #include "src/runtime/CL/mlgo/MLGOHeuristics.h" #include "arm_compute/core/Log.h" + #include "src/runtime/CL/mlgo/MLGOParser.h" #include "src/runtime/CL/mlgo/Utils.h" @@ -39,19 +40,19 @@ bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs) } bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs) { - return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, - rhs.export_cl_image); + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == + std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); } bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs) { - return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0, - rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, + lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0, rhs.interleave_lhs, + rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); } constexpr size_t MLGOHeuristics::_max_num_trees; -MLGOHeuristics::MLGOHeuristics() - : _indices{}, _trees{}, _tree_valid{}, _valid{ false } +MLGOHeuristics::MLGOHeuristics() : _indices{}, _trees{}, _tree_valid{}, _valid{false} { } @@ -59,71 +60,74 @@ std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) co { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str()); const auto invalid = GEMMType::RESHAPED; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query<GEMMType>(shape_query); } std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", + to_string(query).c_str()); const auto invalid = GEMMConfigNative{}; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query<GEMMConfigNative>(shape_query); } std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", + to_string(query).c_str()); const auto invalid = GEMMConfigReshapedOnlyRHS{}; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query); } std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", + to_string(query).c_str()); const auto invalid = GEMMConfigReshaped{}; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query<GEMMConfigReshaped>(shape_query); } @@ -131,14 +135,14 @@ std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(c bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id) { bool status; - HeuristicTree *tree{ nullptr }; + HeuristicTree *tree{nullptr}; std::tie(status, tree) = get_heuristic_tree(id); - if(!status) + if (!status) { return status; } status = tree->check(); - if(!status) + if (!status) { return status; } @@ -149,14 +153,12 @@ bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id) bool MLGOHeuristics::check_all() const { // Tree validities are already checked and cached. - bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v) - { - return !v.second; - }) - == _tree_valid.end(); - if(!all_trees_are_checked) + bool all_trees_are_checked = + std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v) { return !v.second; }) == _tree_valid.end(); + if (!all_trees_are_checked) { - ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo"); + ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each " + "tree is completed. This could also indicate there are no trees in the dotmlgo"); return false; } @@ -167,14 +169,14 @@ bool MLGOHeuristics::check_all() const std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id) { - if(_indices.find(id) == _indices.end()) + if (_indices.find(id) == _indices.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id); return std::make_pair(false, nullptr); } const auto index = _indices[id]; - if(_trees.find(index) == _trees.end()) + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); return std::make_pair(false, nullptr); @@ -186,7 +188,7 @@ std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTre bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) { - if(_indices.size() >= _max_num_trees) + if (_indices.size() >= _max_num_trees) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees); return false; @@ -194,7 +196,7 @@ bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) // PRE: correctness of t is guaranteed by the tree construction process // Ensure unique id const auto id = t.id(); - if(_indices.find(id) != _indices.end()) + if (_indices.find(id) != _indices.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id); return false; @@ -202,7 +204,7 @@ bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) // Ensure unique index const auto index = t.index(); - if(_trees.find(index) != _trees.end()) + if (_trees.find(index) != _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists"); return false; @@ -219,9 +221,10 @@ bool MLGOHeuristics::reload_from_file(const std::string &filename) std::ifstream fs; fs.exceptions(std::ifstream::badbit); fs.open(filename, std::ios::in); - if(!fs.is_open()) + if (!fs.is_open()) { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", + filename.c_str()); return _valid = false; } return reload_from_stream(fs); @@ -230,7 +233,7 @@ bool MLGOHeuristics::reload_from_file(const std::string &filename) bool MLGOHeuristics::reload_from_stream(std::istream &in) { auto parsed = parser::parse_mlgo(in); - if(!parsed.first) + if (!parsed.first) { ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead"); return _valid = false; @@ -241,4 +244,4 @@ bool MLGOHeuristics::reload_from_stream(std::istream &in) } } // namespace mlgo -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |