diff options
Diffstat (limited to 'src/runtime/CL/mlgo/HeuristicTree.cpp')
-rw-r--r-- | src/runtime/CL/mlgo/HeuristicTree.cpp | 89 |
1 files changed, 41 insertions, 48 deletions
diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp index 1c75cdc427..f7b706902b 100644 --- a/src/runtime/CL/mlgo/HeuristicTree.cpp +++ b/src/runtime/CL/mlgo/HeuristicTree.cpp @@ -22,6 +22,7 @@ * SOFTWARE. */ #include "src/runtime/CL/mlgo/HeuristicTree.h" + #include "arm_compute/core/Log.h" #include "support/Cast.h" @@ -40,27 +41,23 @@ bool evaluate(GEMMShape shape, Condition cond) // PRE: all features and ConditionalOps are valid constexpr float eps = 0.0001f; // Calculate all secondary features - std::vector<std::pair<std::string, float>> cond_values - { - { "m", static_cast<float>(shape.m) }, - { "n", static_cast<float>(shape.n) }, - { "k", static_cast<float>(shape.k) }, - { "b", static_cast<float>(shape.b) }, - { "r_mn", static_cast<float>(shape.m) / shape.n }, - { "r_mk", static_cast<float>(shape.m) / shape.k }, - { "r_nk", static_cast<float>(shape.n) / shape.k }, - { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) }, - { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 } - }; - auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(), - [&cond](decltype(*cond_values.begin()) it) - { - return it.first == cond.feature; - }); + std::vector<std::pair<std::string, float>> cond_values{ + {"m", static_cast<float>(shape.m)}, + {"n", static_cast<float>(shape.n)}, + {"k", static_cast<float>(shape.k)}, + {"b", static_cast<float>(shape.b)}, + {"r_mn", static_cast<float>(shape.m) / shape.n}, + {"r_mk", static_cast<float>(shape.m) / shape.k}, + {"r_nk", static_cast<float>(shape.n) / shape.k}, + {"r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k)}, + {"workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0}}; + auto cond_value_pair_it = + std::find_if(cond_values.begin(), cond_values.end(), + [&cond](decltype(*cond_values.begin()) it) { return it.first == cond.feature; }); ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end()); const float cond_value = cond_value_pair_it->second; - switch(cond.op) + switch (cond.op) { case ConditionalOp::LT: { @@ -92,13 +89,12 @@ constexpr size_t HeuristicTree::_max_num_nodes; constexpr size_t HeuristicTree::_max_query_depth; constexpr HeuristicTree::NodeID HeuristicTree::_root; -HeuristicTree::HeuristicTree() - : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32) +HeuristicTree::HeuristicTree() : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32) { } HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type) - : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{} + : _id{id}, _heuristic_type{h_type}, _ip_target{ip_target}, _data_type{data_type}, _tree{} { } @@ -108,16 +104,17 @@ std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const // Root ID = 0; auto cur_node = _tree.at(_root).get(); size_t depth = 0; - while(cur_node->type() != NodeType::Leaf) + while (cur_node->type() != NodeType::Leaf) { - if(depth > _max_query_depth) + if (depth > _max_query_depth) { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", + _max_query_depth); return std::make_pair(false, T{}); } ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType"); auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node); - if(evaluate(shape, br_node->condition)) + if (evaluate(shape, br_node->condition)) { cur_node = _tree.at(br_node->true_node).get(); } @@ -135,12 +132,12 @@ std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const template <typename T> bool HeuristicTree::add_leaf(NodeID id, T val) { - if(_tree.size() >= _max_num_nodes) + if (_tree.size() >= _max_num_nodes) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); return false; } - if(_tree.find(id) != _tree.end()) + if (_tree.find(id) != _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); return false; @@ -151,28 +148,23 @@ bool HeuristicTree::add_leaf(NodeID id, T val) bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node) { - if(_tree.size() >= _max_num_nodes) + if (_tree.size() >= _max_num_nodes) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); return false; } - const std::set<std::string> supported_features = - { - "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload" - }; - const auto orig_feature = cond.feature; - std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c) - { - return std::tolower(c); - }); - if(supported_features.find(cond.feature) == supported_features.end()) + const std::set<std::string> supported_features = {"m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"}; + const auto orig_feature = cond.feature; + std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), + [](char c) { return std::tolower(c); }); + if (supported_features.find(cond.feature) == supported_features.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str()); return false; } - if(_tree.find(id) != _tree.end()) + if (_tree.find(id) != _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); return false; @@ -184,32 +176,32 @@ bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID bool HeuristicTree::check_if_structurally_correct() const { std::set<NodeID> visited; - std::deque<NodeID> to_visit{ _root }; + std::deque<NodeID> to_visit{_root}; - while(!to_visit.empty()) + while (!to_visit.empty()) { auto id = to_visit.front(); to_visit.pop_front(); - if(_tree.find(id) == _tree.end()) + if (_tree.find(id) == _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id); return false; } auto not_seen_before = visited.insert(id); - if(!not_seen_before.second) + if (!not_seen_before.second) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops"); return false; } auto cur_node = _tree.at(id).get(); - if(cur_node->type() == NodeType::Branch) + if (cur_node->type() == NodeType::Branch) { auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node); to_visit.push_back(br_node->true_node); to_visit.push_back(br_node->false_node); } } - if(visited.size() != _tree.size()) + if (visited.size() != _tree.size()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes"); return false; @@ -219,12 +211,12 @@ bool HeuristicTree::check_if_structurally_correct() const bool HeuristicTree::check() { - if(_tree.empty()) + if (_tree.empty()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered"); return false; } - if(_tree.find(_root) == _tree.end()) + if (_tree.find(_root) == _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root); return false; @@ -237,7 +229,8 @@ template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shap /** Explicit template instantiation @relates HeuristicTree */ template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const; /** Explicit template instantiation @relates HeuristicTree */ -template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const; +template std::pair<bool, GEMMConfigReshapedOnlyRHS> +HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const; /** Explicit template instantiation @relates HeuristicTree */ template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const; |