From 7061eb283969f9a020c08349454447564e4dd5b3 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Fri, 8 Jan 2021 15:16:02 +0000 Subject: Implement MLGO module * Implement MLGOHeuristics which provides a query and a loading interface * Implement a top-down parser MLGOParser for parsing dotmlgo * Add validation tests for MLGOHeuristics Resolves COMPMID-3840, COMPMID-3841 Signed-off-by: SiCong Li Change-Id: Iae96d2779524b2dd83623d1a3a30ef57823ae084 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4941 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- src/runtime/CL/mlgo/HeuristicTree.cpp | 253 ++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 src/runtime/CL/mlgo/HeuristicTree.cpp (limited to 'src/runtime/CL/mlgo/HeuristicTree.cpp') diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp new file mode 100644 index 0000000000..65219998cb --- /dev/null +++ b/src/runtime/CL/mlgo/HeuristicTree.cpp @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/CL/mlgo/HeuristicTree.h" +#include "arm_compute/core/Log.h" + +#include +#include +#include +namespace arm_compute +{ +namespace mlgo +{ +namespace +{ +bool evaluate(GEMMShape shape, Condition cond) +{ + // PRE: all features and ConditionalOps are valid + constexpr float eps = 0.0001f; + // Calculate all secondary features + std::vector> cond_values + { + { "m", static_cast(shape.m) }, + { "n", static_cast(shape.n) }, + { "k", static_cast(shape.k) }, + { "b", static_cast(shape.b) }, + { "r_mn", static_cast(shape.m) / shape.n }, + { "r_mk", static_cast(shape.m) / shape.k }, + { "r_nk", static_cast(shape.n) / shape.k }, + { "r_mnk", static_cast(shape.m) / (static_cast(shape.n) / shape.k) }, + { "workload", (static_cast(shape.m) * shape.n * shape.b) / 20.0 } + }; + auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(), + [&cond](decltype(*cond_values.begin()) it) + { + return it.first == cond.feature; + }); + + ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end()); + const float cond_value = cond_value_pair_it->second; + switch(cond.op) + { + case ConditionalOp::LT: + { + return cond_value < cond.threshold; + } + case ConditionalOp::LE: + { + return cond_value <= cond.threshold; + } + case ConditionalOp::GT: + { + return cond_value > cond.threshold; + } + case ConditionalOp::GE: + { + return cond_value >= cond.threshold; + } + case ConditionalOp::EQ: + default: + { + return std::abs(cond_value - cond.threshold) < eps; + } + } +} + +} // namespace + +constexpr size_t HeuristicTree::_max_num_nodes; +constexpr size_t HeuristicTree::_max_query_depth; +constexpr HeuristicTree::NodeID HeuristicTree::_root; + +HeuristicTree::HeuristicTree() + : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32) +{ +} + +HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type) + : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{} +{ +} + +template +std::pair HeuristicTree::query(GEMMShape shape) const +{ + // Root ID = 0; + auto cur_node = _tree.at(_root).get(); + size_t depth = 0; + while(cur_node->type() != NodeType::Leaf) + { + if(depth > _max_query_depth) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth); + return std::make_pair(false, T{}); + } + ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType"); + auto br_node = dynamic_cast(cur_node); + if(evaluate(shape, br_node->condition)) + { + cur_node = _tree.at(br_node->true_node).get(); + } + else + { + cur_node = _tree.at(br_node->false_node).get(); + } + ++depth; + } + ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType"); + auto l_node = dynamic_cast *>(cur_node); + return std::make_pair(true, l_node->value); +} + +template +bool HeuristicTree::add_leaf(NodeID id, T val) +{ + if(_tree.size() >= _max_num_nodes) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); + return false; + } + if(_tree.find(id) != _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); + return false; + } + _tree[id] = std::make_unique>(id, val); + return true; +} + +bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node) +{ + if(_tree.size() >= _max_num_nodes) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); + return false; + } + + const std::set supported_features = + { + "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload" + }; + const auto orig_feature = cond.feature; + std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c) + { + return std::tolower(c); + }); + if(supported_features.find(cond.feature) == supported_features.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str()); + return false; + } + + if(_tree.find(id) != _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); + return false; + } + _tree[id] = std::make_unique(id, cond, t_node, f_node); + return true; +} + +bool HeuristicTree::check_if_structurally_correct() const +{ + std::set visited; + std::deque to_visit{ _root }; + + while(!to_visit.empty()) + { + auto id = to_visit.front(); + to_visit.pop_front(); + if(_tree.find(id) == _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id); + return false; + } + auto not_seen_before = visited.insert(id); + if(!not_seen_before.second) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops"); + return false; + } + auto cur_node = _tree.at(id).get(); + if(cur_node->type() == NodeType::Branch) + { + auto br_node = dynamic_cast(cur_node); + to_visit.push_back(br_node->true_node); + to_visit.push_back(br_node->false_node); + } + } + if(visited.size() != _tree.size()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes"); + return false; + } + return true; +} + +bool HeuristicTree::check() +{ + if(_tree.empty()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered"); + return false; + } + if(_tree.find(_root) == _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root); + return false; + } + return check_if_structurally_correct(); +} + +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; + +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMType val); +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val); +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val); +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val); + +} // namespace mlgo + +} // namespace arm_compute -- cgit v1.2.1