aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/mlgo/HeuristicTree.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/mlgo/HeuristicTree.cpp')
-rw-r--r--src/runtime/CL/mlgo/HeuristicTree.cpp89
1 files changed, 41 insertions, 48 deletions
diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp
index 1c75cdc427..f7b706902b 100644
--- a/src/runtime/CL/mlgo/HeuristicTree.cpp
+++ b/src/runtime/CL/mlgo/HeuristicTree.cpp
@@ -22,6 +22,7 @@
* SOFTWARE.
*/
#include "src/runtime/CL/mlgo/HeuristicTree.h"
+
#include "arm_compute/core/Log.h"
#include "support/Cast.h"
@@ -40,27 +41,23 @@ bool evaluate(GEMMShape shape, Condition cond)
// PRE: all features and ConditionalOps are valid
constexpr float eps = 0.0001f;
// Calculate all secondary features
- std::vector<std::pair<std::string, float>> cond_values
- {
- { "m", static_cast<float>(shape.m) },
- { "n", static_cast<float>(shape.n) },
- { "k", static_cast<float>(shape.k) },
- { "b", static_cast<float>(shape.b) },
- { "r_mn", static_cast<float>(shape.m) / shape.n },
- { "r_mk", static_cast<float>(shape.m) / shape.k },
- { "r_nk", static_cast<float>(shape.n) / shape.k },
- { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
- { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
- };
- auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
- [&cond](decltype(*cond_values.begin()) it)
- {
- return it.first == cond.feature;
- });
+ std::vector<std::pair<std::string, float>> cond_values{
+ {"m", static_cast<float>(shape.m)},
+ {"n", static_cast<float>(shape.n)},
+ {"k", static_cast<float>(shape.k)},
+ {"b", static_cast<float>(shape.b)},
+ {"r_mn", static_cast<float>(shape.m) / shape.n},
+ {"r_mk", static_cast<float>(shape.m) / shape.k},
+ {"r_nk", static_cast<float>(shape.n) / shape.k},
+ {"r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k)},
+ {"workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0}};
+ auto cond_value_pair_it =
+ std::find_if(cond_values.begin(), cond_values.end(),
+ [&cond](decltype(*cond_values.begin()) it) { return it.first == cond.feature; });
ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
const float cond_value = cond_value_pair_it->second;
- switch(cond.op)
+ switch (cond.op)
{
case ConditionalOp::LT:
{
@@ -92,13 +89,12 @@ constexpr size_t HeuristicTree::_max_num_nodes;
constexpr size_t HeuristicTree::_max_query_depth;
constexpr HeuristicTree::NodeID HeuristicTree::_root;
-HeuristicTree::HeuristicTree()
- : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
+HeuristicTree::HeuristicTree() : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
{
}
HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
- : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
+ : _id{id}, _heuristic_type{h_type}, _ip_target{ip_target}, _data_type{data_type}, _tree{}
{
}
@@ -108,16 +104,17 @@ std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
// Root ID = 0;
auto cur_node = _tree.at(_root).get();
size_t depth = 0;
- while(cur_node->type() != NodeType::Leaf)
+ while (cur_node->type() != NodeType::Leaf)
{
- if(depth > _max_query_depth)
+ if (depth > _max_query_depth)
{
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?",
+ _max_query_depth);
return std::make_pair(false, T{});
}
ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
- if(evaluate(shape, br_node->condition))
+ if (evaluate(shape, br_node->condition))
{
cur_node = _tree.at(br_node->true_node).get();
}
@@ -135,12 +132,12 @@ std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
template <typename T>
bool HeuristicTree::add_leaf(NodeID id, T val)
{
- if(_tree.size() >= _max_num_nodes)
+ if (_tree.size() >= _max_num_nodes)
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
return false;
}
- if(_tree.find(id) != _tree.end())
+ if (_tree.find(id) != _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
return false;
@@ -151,28 +148,23 @@ bool HeuristicTree::add_leaf(NodeID id, T val)
bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
{
- if(_tree.size() >= _max_num_nodes)
+ if (_tree.size() >= _max_num_nodes)
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
return false;
}
- const std::set<std::string> supported_features =
- {
- "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
- };
- const auto orig_feature = cond.feature;
- std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
- {
- return std::tolower(c);
- });
- if(supported_features.find(cond.feature) == supported_features.end())
+ const std::set<std::string> supported_features = {"m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"};
+ const auto orig_feature = cond.feature;
+ std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(),
+ [](char c) { return std::tolower(c); });
+ if (supported_features.find(cond.feature) == supported_features.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
return false;
}
- if(_tree.find(id) != _tree.end())
+ if (_tree.find(id) != _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
return false;
@@ -184,32 +176,32 @@ bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID
bool HeuristicTree::check_if_structurally_correct() const
{
std::set<NodeID> visited;
- std::deque<NodeID> to_visit{ _root };
+ std::deque<NodeID> to_visit{_root};
- while(!to_visit.empty())
+ while (!to_visit.empty())
{
auto id = to_visit.front();
to_visit.pop_front();
- if(_tree.find(id) == _tree.end())
+ if (_tree.find(id) == _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
return false;
}
auto not_seen_before = visited.insert(id);
- if(!not_seen_before.second)
+ if (!not_seen_before.second)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
return false;
}
auto cur_node = _tree.at(id).get();
- if(cur_node->type() == NodeType::Branch)
+ if (cur_node->type() == NodeType::Branch)
{
auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
to_visit.push_back(br_node->true_node);
to_visit.push_back(br_node->false_node);
}
}
- if(visited.size() != _tree.size())
+ if (visited.size() != _tree.size())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
return false;
@@ -219,12 +211,12 @@ bool HeuristicTree::check_if_structurally_correct() const
bool HeuristicTree::check()
{
- if(_tree.empty())
+ if (_tree.empty())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
return false;
}
- if(_tree.find(_root) == _tree.end())
+ if (_tree.find(_root) == _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
return false;
@@ -237,7 +229,8 @@ template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shap
/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
/** Explicit template instantiation @relates HeuristicTree */
-template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
+template std::pair<bool, GEMMConfigReshapedOnlyRHS>
+HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;