/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/runtime/CL/mlgo/HeuristicTree.h"
#include "arm_compute/core/Log.h"

#include "support/Cast.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <utility>

namespace arm_compute
{
namespace mlgo
{
namespace
{
/** Compute the value of a single (primary or secondary) GEMM shape feature.
 *
 * Only the requested feature is computed. This avoids building a table of every feature
 * (one std::string allocation per entry) each time a branch condition is evaluated.
 *
 * Secondary features are ratios of the primary ones and are evaluated in single precision.
 * A zero divisor therefore yields inf/nan rather than trapping (assuming IEEE-754 floats).
 *
 * @param[in]  shape   GEMM shape to extract the feature from
 * @param[in]  feature Lower-case feature name, as stored by HeuristicTree::add_branch
 * @param[out] value   Computed feature value; left untouched if @p feature is unknown
 *
 * @return true if @p feature is a known feature name, false otherwise
 */
bool compute_feature(const GEMMShape &shape, const std::string &feature, float &value)
{
    const float m = static_cast<float>(shape.m);
    const float n = static_cast<float>(shape.n);
    const float k = static_cast<float>(shape.k);
    const float b = static_cast<float>(shape.b);

    if(feature == "m")
    {
        value = m;
    }
    else if(feature == "n")
    {
        value = n;
    }
    else if(feature == "k")
    {
        value = k;
    }
    else if(feature == "b")
    {
        value = b;
    }
    else if(feature == "r_mn")
    {
        value = m / n;
    }
    else if(feature == "r_mk")
    {
        value = m / k;
    }
    else if(feature == "r_nk")
    {
        value = n / k;
    }
    else if(feature == "r_mnk")
    {
        value = m / (n / k);
    }
    else if(feature == "workload")
    {
        // The division by the double literal 20.0 is performed in double precision, then narrowed back to float
        value = static_cast<float>((m * n * b) / 20.0);
    }
    else
    {
        return false;
    }
    return true;
}

/** Evaluate a branch condition against a GEMM shape.
 *
 * @pre cond.feature is a lower-case supported feature name (add_branch normalises and validates it)
 *
 * @param[in] shape GEMM shape to test
 * @param[in] cond  Condition (feature, comparison operator, threshold) to evaluate
 *
 * @return true if the condition holds for @p shape, false otherwise (including for an unknown feature)
 */
bool evaluate(const GEMMShape &shape, const Condition &cond)
{
    // Absolute tolerance used for the floating-point equality comparison
    constexpr float eps = 0.0001f;

    float      cond_value    = 0.f;
    const bool known_feature = compute_feature(shape, cond.feature, cond_value);
    ARM_COMPUTE_ERROR_ON_MSG(!known_feature, "Unsupported feature");
    if(!known_feature)
    {
        // ARM_COMPUTE_ERROR_ON_MSG may be compiled out (e.g. with asserts disabled);
        // fail the condition instead of using an undefined feature value
        return false;
    }

    switch(cond.op)
    {
        case ConditionalOp::LT:
        {
            return cond_value < cond.threshold;
        }
        case ConditionalOp::LE:
        {
            return cond_value <= cond.threshold;
        }
        case ConditionalOp::GT:
        {
            return cond_value > cond.threshold;
        }
        case ConditionalOp::GE:
        {
            return cond_value >= cond.threshold;
        }
        case ConditionalOp::EQ:
        default:
        {
            return std::abs(cond_value - cond.threshold) < eps;
        }
    }
}

} // namespace

constexpr size_t                HeuristicTree::_max_num_nodes;
constexpr size_t                HeuristicTree::_max_query_depth;
constexpr HeuristicTree::NodeID HeuristicTree::_root;

HeuristicTree::HeuristicTree()
    : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
{
}

HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
    : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
{
}

/** Walk the tree from the root, following branch conditions, until a leaf is reached.
 *
 * @note T must match the leaf value type that was used with add_leaf() for this tree.
 *
 * @param[in] shape GEMM shape used to evaluate the branch conditions
 *
 * @return {true, leaf value} on success.
 *         {false, T{}} if the query depth limit is exceeded or a referenced node is missing.
 */
template <typename T>
std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
{
    // Use find() rather than at() so that a malformed tree (e.g. one not validated with check())
    // is reported through the returned flag instead of throwing std::out_of_range
    auto node_it = _tree.find(_root);
    if(node_it == _tree.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", _root);
        return std::make_pair(false, T{});
    }
    auto   cur_node = node_it->second.get();
    size_t depth    = 0;
    while(cur_node->type() != NodeType::Leaf)
    {
        if(depth > _max_query_depth)
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
            return std::make_pair(false, T{});
        }
        ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
        auto         br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
        const NodeID next_id = evaluate(shape, br_node->condition) ? br_node->true_node : br_node->false_node;

        node_it = _tree.find(next_id);
        if(node_it == _tree.end())
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", next_id);
            return std::make_pair(false, T{});
        }
        cur_node = node_it->second.get();
        ++depth;
    }
    ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
    auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
    return std::make_pair(true, l_node->value);
}

/** Add a leaf node holding @p val under ID @p id.
 *
 * @return false (and log) if the node limit is reached or @p id is already used, true otherwise
 */
template <typename T>
bool HeuristicTree::add_leaf(NodeID id, T val)
{
    if(_tree.size() >= _max_num_nodes)
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
        return false;
    }
    if(_tree.find(id) != _tree.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
        return false;
    }
    _tree[id] = std::make_unique<LeafNode<T>>(id, val);
    return true;
}

/** Add a branch node under ID @p id that jumps to @p t_node when @p cond holds and to @p f_node otherwise.
 *
 * The condition's feature name is matched case-insensitively and stored in lower case.
 *
 * @return false (and log) if the node limit is reached, the feature is unsupported or @p id is already used;
 *         true otherwise
 */
bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
{
    if(_tree.size() >= _max_num_nodes)
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
        return false;
    }

    // Must stay in sync with the feature names handled by compute_feature()
    static const std::set<std::string> supported_features =
    {
        "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
    };
    const auto orig_feature = cond.feature;
    // std::tolower has undefined behaviour for negative char values (e.g. non-ASCII bytes from a
    // heuristics file), so convert through unsigned char first
    std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
    {
        return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    });
    if(supported_features.find(cond.feature) == supported_features.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
        return false;
    }

    if(_tree.find(id) != _tree.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
        return false;
    }
    _tree[id] = std::make_unique<BranchNode>(id, std::move(cond), t_node, f_node);
    return true;
}

/** Verify that the nodes form a single tree rooted at _root.
 *
 * A breadth-first traversal from the root checks three things:
 * - every referenced node exists;
 * - no node is reached more than once (which catches cycles and shared children);
 * - every stored node is reachable from the root.
 *
 * @return true if the structure is a valid tree, false (and log) otherwise
 */
bool HeuristicTree::check_if_structurally_correct() const
{
    std::set<NodeID>   visited;
    std::deque<NodeID> to_visit{ _root };
    while(!to_visit.empty())
    {
        const NodeID id = to_visit.front();
        to_visit.pop_front();

        const auto node_it = _tree.find(id);
        if(node_it == _tree.end())
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
            return false;
        }
        if(!visited.insert(id).second)
        {
            ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
            return false;
        }
        auto cur_node = node_it->second.get();
        if(cur_node->type() == NodeType::Branch)
        {
            auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
            to_visit.push_back(br_node->true_node);
            to_visit.push_back(br_node->false_node);
        }
    }
    // Any stored node not visited above is unreachable from the root
    if(visited.size() != _tree.size())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
        return false;
    }
    return true;
}

/** Validate the tree: it must be non-empty, contain the root and be structurally a tree. */
bool HeuristicTree::check()
{
    if(_tree.empty())
    {
        ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
        return false;
    }
    if(_tree.find(_root) == _tree.end())
    {
        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
        return false;
    }
    return check_if_structurally_correct();
}

/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;

/** Explicit template instantiation @relates HeuristicTree */
template bool HeuristicTree::add_leaf<GEMMType>(NodeID id, GEMMType val);
/** Explicit template instantiation @relates HeuristicTree */
template bool HeuristicTree::add_leaf<GEMMConfigNative>(NodeID id, GEMMConfigNative val);
/** Explicit template instantiation @relates HeuristicTree */
template bool HeuristicTree::add_leaf<GEMMConfigReshapedOnlyRHS>(NodeID id, GEMMConfigReshapedOnlyRHS val);
/** Explicit template instantiation @relates HeuristicTree */
template bool HeuristicTree::add_leaf<GEMMConfigReshaped>(NodeID id, GEMMConfigReshaped val);
} // namespace mlgo
} // namespace arm_compute