aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/mlgo/HeuristicTree.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/mlgo/HeuristicTree.cpp')
-rw-r--r--src/runtime/CL/mlgo/HeuristicTree.cpp253
1 files changed, 253 insertions, 0 deletions
diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp
new file mode 100644
index 0000000000..65219998cb
--- /dev/null
+++ b/src/runtime/CL/mlgo/HeuristicTree.cpp
@@ -0,0 +1,253 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/CL/mlgo/HeuristicTree.h"
+#include "arm_compute/core/Log.h"
+
+#include <algorithm>
+#include <deque>
+#include <set>
+namespace arm_compute
+{
+namespace mlgo
+{
+namespace
+{
+bool evaluate(GEMMShape shape, Condition cond)
+{
+ // PRE: all features and ConditionalOps are valid
+ constexpr float eps = 0.0001f;
+ // Calculate all secondary features
+ std::vector<std::pair<std::string, float>> cond_values
+ {
+ { "m", static_cast<float>(shape.m) },
+ { "n", static_cast<float>(shape.n) },
+ { "k", static_cast<float>(shape.k) },
+ { "b", static_cast<float>(shape.b) },
+ { "r_mn", static_cast<float>(shape.m) / shape.n },
+ { "r_mk", static_cast<float>(shape.m) / shape.k },
+ { "r_nk", static_cast<float>(shape.n) / shape.k },
+ { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
+ { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
+ };
+ auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
+ [&cond](decltype(*cond_values.begin()) it)
+ {
+ return it.first == cond.feature;
+ });
+
+ ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
+ const float cond_value = cond_value_pair_it->second;
+ switch(cond.op)
+ {
+ case ConditionalOp::LT:
+ {
+ return cond_value < cond.threshold;
+ }
+ case ConditionalOp::LE:
+ {
+ return cond_value <= cond.threshold;
+ }
+ case ConditionalOp::GT:
+ {
+ return cond_value > cond.threshold;
+ }
+ case ConditionalOp::GE:
+ {
+ return cond_value >= cond.threshold;
+ }
+ case ConditionalOp::EQ:
+ default:
+ {
+ return std::abs(cond_value - cond.threshold) < eps;
+ }
+ }
+}
+
+} // namespace
+
+constexpr size_t HeuristicTree::_max_num_nodes;
+constexpr size_t HeuristicTree::_max_query_depth;
+constexpr HeuristicTree::NodeID HeuristicTree::_root;
+
+HeuristicTree::HeuristicTree()
+ : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
+{
+}
+
+HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
+ : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
+{
+}
+
+template <typename T>
+std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
+{
+ // Root ID = 0;
+ auto cur_node = _tree.at(_root).get();
+ size_t depth = 0;
+ while(cur_node->type() != NodeType::Leaf)
+ {
+ if(depth > _max_query_depth)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
+ return std::make_pair(false, T{});
+ }
+ ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
+ auto br_node = dynamic_cast<BranchNode *>(cur_node);
+ if(evaluate(shape, br_node->condition))
+ {
+ cur_node = _tree.at(br_node->true_node).get();
+ }
+ else
+ {
+ cur_node = _tree.at(br_node->false_node).get();
+ }
+ ++depth;
+ }
+ ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
+ auto l_node = dynamic_cast<LeafNode<T> *>(cur_node);
+ return std::make_pair(true, l_node->value);
+}
+
+template <typename T>
+bool HeuristicTree::add_leaf(NodeID id, T val)
+{
+ if(_tree.size() >= _max_num_nodes)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
+ return false;
+ }
+ if(_tree.find(id) != _tree.end())
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
+ return false;
+ }
+ _tree[id] = std::make_unique<LeafNode<T>>(id, val);
+ return true;
+}
+
+bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
+{
+ if(_tree.size() >= _max_num_nodes)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
+ return false;
+ }
+
+ const std::set<std::string> supported_features =
+ {
+ "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
+ };
+ const auto orig_feature = cond.feature;
+ std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
+ {
+ return std::tolower(c);
+ });
+ if(supported_features.find(cond.feature) == supported_features.end())
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
+ return false;
+ }
+
+ if(_tree.find(id) != _tree.end())
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
+ return false;
+ }
+ _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
+ return true;
+}
+
+bool HeuristicTree::check_if_structurally_correct() const
+{
+ std::set<NodeID> visited;
+ std::deque<NodeID> to_visit{ _root };
+
+ while(!to_visit.empty())
+ {
+ auto id = to_visit.front();
+ to_visit.pop_front();
+ if(_tree.find(id) == _tree.end())
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
+ return false;
+ }
+ auto not_seen_before = visited.insert(id);
+ if(!not_seen_before.second)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
+ return false;
+ }
+ auto cur_node = _tree.at(id).get();
+ if(cur_node->type() == NodeType::Branch)
+ {
+ auto br_node = dynamic_cast<BranchNode *>(cur_node);
+ to_visit.push_back(br_node->true_node);
+ to_visit.push_back(br_node->false_node);
+ }
+ }
+ if(visited.size() != _tree.size())
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
+ return false;
+ }
+ return true;
+}
+
+bool HeuristicTree::check()
+{
+ if(_tree.empty())
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
+ return false;
+ }
+ if(_tree.find(_root) == _tree.end())
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
+ return false;
+ }
+ return check_if_structurally_correct();
+}
+
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
+
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val);
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
+
+} // namespace mlgo
+
+} // namespace arm_compute