aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/mlgo
diff options
context:
space:
mode:
authorFelix Thomasmathibalan <felixjohnny.thomasmathibalan@arm.com>2023-09-27 17:46:17 +0100
committerfelixjohnny.thomasmathibalan <felixjohnny.thomasmathibalan@arm.com>2023-09-28 12:08:05 +0000
commitafd38f0c617d6f89b2b4532c6c44f116617e2b6f (patch)
tree03bc7d5a762099989b16a656fa8d397b490ed70e /src/runtime/CL/mlgo
parentbdcb4c148ee2fdeaaddf4cf1e57bbb0de02bb894 (diff)
downloadComputeLibrary-afd38f0c617d6f89b2b4532c6c44f116617e2b6f.tar.gz
Apply clang-format on repository
Code is formatted as per a revised clang format configuration file(not part of this delivery). Version 14.0.6 is used. Exclusion List: - files with .cl extension - files that are not strictly C/C++ (e.g. Android.bp, Sconscript ...) And the following directories - compute_kernel_writer/validation/ - tests/ - include/ - src/core/NEON/kernels/convolution/ - src/core/NEON/kernels/arm_gemm/ - src/core/NEON/kernels/arm_conv/ - data/ There will be a follow up for formatting of .cl files and the files under tests/ and compute_kernel_writer/validation/. Signed-off-by: Felix Thomasmathibalan <felixjohnny.thomasmathibalan@arm.com> Change-Id: Ib7eb1fcf4e7537b9feaefcfc15098a804a3fde0a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10391 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Diffstat (limited to 'src/runtime/CL/mlgo')
-rw-r--r--src/runtime/CL/mlgo/Common.h40
-rw-r--r--src/runtime/CL/mlgo/HeuristicTree.cpp89
-rw-r--r--src/runtime/CL/mlgo/HeuristicTree.h24
-rw-r--r--src/runtime/CL/mlgo/MLGOHeuristics.cpp99
-rw-r--r--src/runtime/CL/mlgo/MLGOHeuristics.h6
-rw-r--r--src/runtime/CL/mlgo/MLGOParser.cpp188
-rw-r--r--src/runtime/CL/mlgo/MLGOParser.h9
-rw-r--r--src/runtime/CL/mlgo/Utils.cpp48
-rw-r--r--src/runtime/CL/mlgo/Utils.h10
9 files changed, 250 insertions, 263 deletions
diff --git a/src/runtime/CL/mlgo/Common.h b/src/runtime/CL/mlgo/Common.h
index c451bd9062..08a7ee8c18 100644
--- a/src/runtime/CL/mlgo/Common.h
+++ b/src/runtime/CL/mlgo/Common.h
@@ -45,37 +45,37 @@ using GEMMType = CLGEMMKernelType;
/** GEMM Configuration for Native kernel */
struct GEMMConfigNative
{
- unsigned int m0{ 1 }; /**< Number of rows processed by the matrix multiplication */
- unsigned int n0{ 1 }; /**< Number of columns processed by the matrix multiplication */
- unsigned int k0{ 1 }; /**< Number of partial accumulations performed by the matrix multiplication */
+ unsigned int m0{1}; /**< Number of rows processed by the matrix multiplication */
+ unsigned int n0{1}; /**< Number of columns processed by the matrix multiplication */
+ unsigned int k0{1}; /**< Number of partial accumulations performed by the matrix multiplication */
};
/** GEMM Configuration for Reshaped Only RHS kernel */
struct GEMMConfigReshapedOnlyRHS
{
- unsigned int m0{ 1 }; /**< Number of rows processed by the matrix multiplication */
- unsigned int n0{ 1 }; /**< Number of columns processed by the matrix multiplication */
- unsigned int k0{ 1 }; /**< Number of partial accumulations performed by the matrix multiplication */
- unsigned int h0{ 1 }; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
- bool interleave_rhs{ false }; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */
- bool transpose_rhs{ false }; /**< True if the (k0xn0) block has to be transposed before been stored */
- bool export_cl_image{ false }; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */
+ unsigned int m0{1}; /**< Number of rows processed by the matrix multiplication */
+ unsigned int n0{1}; /**< Number of columns processed by the matrix multiplication */
+ unsigned int k0{1}; /**< Number of partial accumulations performed by the matrix multiplication */
+ unsigned int h0{1}; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
+ bool interleave_rhs{false}; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */
+ bool transpose_rhs{false}; /**< True if the (k0xn0) block has to be transposed before been stored */
+ bool export_cl_image{false}; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */
};
/** GEMM Configuration for Reshaped kernel */
struct GEMMConfigReshaped
{
- unsigned int m0{ 1 }; /**< Number of rows processed by the matrix multiplication */
- unsigned int n0{ 1 }; /**< Number of columns processed by the matrix multiplication */
- unsigned int k0{ 1 }; /**< Number of partial accumulations performed by the matrix multiplication */
- unsigned int v0{ 1 }; /**< Number of vertical blocks of size (m0xk0) stored on the same output row */
- unsigned int h0{ 1 }; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
- bool interleave_lhs{ false }; /**< True if the v0 (m0xk0) blocks have to be interleaved in the output row */
- bool interleave_rhs{ false }; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */
- bool transpose_rhs{ false }; /**< True if the (k0xn0) block has to be transposed before been stored */
- bool export_cl_image{ false }; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */
+ unsigned int m0{1}; /**< Number of rows processed by the matrix multiplication */
+ unsigned int n0{1}; /**< Number of columns processed by the matrix multiplication */
+ unsigned int k0{1}; /**< Number of partial accumulations performed by the matrix multiplication */
+ unsigned int v0{1}; /**< Number of vertical blocks of size (m0xk0) stored on the same output row */
+ unsigned int h0{1}; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
+ bool interleave_lhs{false}; /**< True if the v0 (m0xk0) blocks have to be interleaved in the output row */
+ bool interleave_rhs{false}; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */
+ bool transpose_rhs{false}; /**< True if the (k0xn0) block has to be transposed before been stored */
+ bool export_cl_image{false}; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */
};
} // namespace mlgo
} // namespace arm_compute
-#endif // SRC_RUNTIME_CL_MLGO_COMMON_H \ No newline at end of file
+#endif // SRC_RUNTIME_CL_MLGO_COMMON_H
diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp
index 1c75cdc427..f7b706902b 100644
--- a/src/runtime/CL/mlgo/HeuristicTree.cpp
+++ b/src/runtime/CL/mlgo/HeuristicTree.cpp
@@ -22,6 +22,7 @@
* SOFTWARE.
*/
#include "src/runtime/CL/mlgo/HeuristicTree.h"
+
#include "arm_compute/core/Log.h"
#include "support/Cast.h"
@@ -40,27 +41,23 @@ bool evaluate(GEMMShape shape, Condition cond)
// PRE: all features and ConditionalOps are valid
constexpr float eps = 0.0001f;
// Calculate all secondary features
- std::vector<std::pair<std::string, float>> cond_values
- {
- { "m", static_cast<float>(shape.m) },
- { "n", static_cast<float>(shape.n) },
- { "k", static_cast<float>(shape.k) },
- { "b", static_cast<float>(shape.b) },
- { "r_mn", static_cast<float>(shape.m) / shape.n },
- { "r_mk", static_cast<float>(shape.m) / shape.k },
- { "r_nk", static_cast<float>(shape.n) / shape.k },
- { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
- { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
- };
- auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
- [&cond](decltype(*cond_values.begin()) it)
- {
- return it.first == cond.feature;
- });
+ std::vector<std::pair<std::string, float>> cond_values{
+ {"m", static_cast<float>(shape.m)},
+ {"n", static_cast<float>(shape.n)},
+ {"k", static_cast<float>(shape.k)},
+ {"b", static_cast<float>(shape.b)},
+ {"r_mn", static_cast<float>(shape.m) / shape.n},
+ {"r_mk", static_cast<float>(shape.m) / shape.k},
+ {"r_nk", static_cast<float>(shape.n) / shape.k},
+ {"r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k)},
+ {"workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0}};
+ auto cond_value_pair_it =
+ std::find_if(cond_values.begin(), cond_values.end(),
+ [&cond](decltype(*cond_values.begin()) it) { return it.first == cond.feature; });
ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
const float cond_value = cond_value_pair_it->second;
- switch(cond.op)
+ switch (cond.op)
{
case ConditionalOp::LT:
{
@@ -92,13 +89,12 @@ constexpr size_t HeuristicTree::_max_num_nodes;
constexpr size_t HeuristicTree::_max_query_depth;
constexpr HeuristicTree::NodeID HeuristicTree::_root;
-HeuristicTree::HeuristicTree()
- : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
+HeuristicTree::HeuristicTree() : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
{
}
HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
- : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
+ : _id{id}, _heuristic_type{h_type}, _ip_target{ip_target}, _data_type{data_type}, _tree{}
{
}
@@ -108,16 +104,17 @@ std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
// Root ID = 0;
auto cur_node = _tree.at(_root).get();
size_t depth = 0;
- while(cur_node->type() != NodeType::Leaf)
+ while (cur_node->type() != NodeType::Leaf)
{
- if(depth > _max_query_depth)
+ if (depth > _max_query_depth)
{
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?",
+ _max_query_depth);
return std::make_pair(false, T{});
}
ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
- if(evaluate(shape, br_node->condition))
+ if (evaluate(shape, br_node->condition))
{
cur_node = _tree.at(br_node->true_node).get();
}
@@ -135,12 +132,12 @@ std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
template <typename T>
bool HeuristicTree::add_leaf(NodeID id, T val)
{
- if(_tree.size() >= _max_num_nodes)
+ if (_tree.size() >= _max_num_nodes)
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
return false;
}
- if(_tree.find(id) != _tree.end())
+ if (_tree.find(id) != _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
return false;
@@ -151,28 +148,23 @@ bool HeuristicTree::add_leaf(NodeID id, T val)
bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
{
- if(_tree.size() >= _max_num_nodes)
+ if (_tree.size() >= _max_num_nodes)
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
return false;
}
- const std::set<std::string> supported_features =
- {
- "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
- };
- const auto orig_feature = cond.feature;
- std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
- {
- return std::tolower(c);
- });
- if(supported_features.find(cond.feature) == supported_features.end())
+ const std::set<std::string> supported_features = {"m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"};
+ const auto orig_feature = cond.feature;
+ std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(),
+ [](char c) { return std::tolower(c); });
+ if (supported_features.find(cond.feature) == supported_features.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
return false;
}
- if(_tree.find(id) != _tree.end())
+ if (_tree.find(id) != _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
return false;
@@ -184,32 +176,32 @@ bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID
bool HeuristicTree::check_if_structurally_correct() const
{
std::set<NodeID> visited;
- std::deque<NodeID> to_visit{ _root };
+ std::deque<NodeID> to_visit{_root};
- while(!to_visit.empty())
+ while (!to_visit.empty())
{
auto id = to_visit.front();
to_visit.pop_front();
- if(_tree.find(id) == _tree.end())
+ if (_tree.find(id) == _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
return false;
}
auto not_seen_before = visited.insert(id);
- if(!not_seen_before.second)
+ if (!not_seen_before.second)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
return false;
}
auto cur_node = _tree.at(id).get();
- if(cur_node->type() == NodeType::Branch)
+ if (cur_node->type() == NodeType::Branch)
{
auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
to_visit.push_back(br_node->true_node);
to_visit.push_back(br_node->false_node);
}
}
- if(visited.size() != _tree.size())
+ if (visited.size() != _tree.size())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
return false;
@@ -219,12 +211,12 @@ bool HeuristicTree::check_if_structurally_correct() const
bool HeuristicTree::check()
{
- if(_tree.empty())
+ if (_tree.empty())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
return false;
}
- if(_tree.find(_root) == _tree.end())
+ if (_tree.find(_root) == _tree.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
return false;
@@ -237,7 +229,8 @@ template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shap
/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
/** Explicit template instantiation @relates HeuristicTree */
-template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
+template std::pair<bool, GEMMConfigReshapedOnlyRHS>
+HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
/** Explicit template instantiation @relates HeuristicTree */
template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
diff --git a/src/runtime/CL/mlgo/HeuristicTree.h b/src/runtime/CL/mlgo/HeuristicTree.h
index d5c7de2215..a4f8c116b9 100644
--- a/src/runtime/CL/mlgo/HeuristicTree.h
+++ b/src/runtime/CL/mlgo/HeuristicTree.h
@@ -25,6 +25,7 @@
#define SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
#include "arm_compute/core/Types.h"
+
#include "src/runtime/CL/mlgo/Common.h"
#include <map>
@@ -84,7 +85,7 @@ public:
struct BranchNode : public Node
{
BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
- : id{ id }, condition{ cond }, true_node{ t_node }, false_node{ f_node }
+ : id{id}, condition{cond}, true_node{t_node}, false_node{f_node}
{
}
NodeType type() const override
@@ -100,8 +101,7 @@ public:
template <typename T>
struct LeafNode : public Node
{
- LeafNode(NodeID id, T val)
- : id{ id }, value{ val }
+ LeafNode(NodeID id, T val) : id{id}, value{val}
{
}
NodeType type() const override
@@ -177,22 +177,22 @@ public:
bool check();
private:
- static constexpr size_t _max_query_depth{ 1000 }; // Maximum depth of query
- static constexpr size_t _max_num_nodes{ 100000 }; // Maximum number of nodes contained by the tree
- static constexpr NodeID _root{ 0 }; // Root tree ID
+ static constexpr size_t _max_query_depth{1000}; // Maximum depth of query
+ static constexpr size_t _max_num_nodes{100000}; // Maximum number of nodes contained by the tree
+ static constexpr NodeID _root{0}; // Root tree ID
private:
bool check_if_structurally_correct() const;
private:
- TreeID _id; /**< Heuristic tree ID */
- HeuristicType _heuristic_type; /**< Heuristic type */
- std::string _ip_target; /**< IP target associated with the tree */
- DataType _data_type; /**< Data type associated with the tree */
- std::map<NodeID, std::unique_ptr<Node>> _tree; /**< Tree representation */
+ TreeID _id; /**< Heuristic tree ID */
+ HeuristicType _heuristic_type; /**< Heuristic type */
+ std::string _ip_target; /**< IP target associated with the tree */
+ DataType _data_type; /**< Data type associated with the tree */
+ std::map<NodeID, std::unique_ptr<Node>> _tree; /**< Tree representation */
};
} // namespace mlgo
} // namespace arm_compute
-#endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H \ No newline at end of file
+#endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.cpp b/src/runtime/CL/mlgo/MLGOHeuristics.cpp
index 80f3bb85e9..aed46cd80f 100644
--- a/src/runtime/CL/mlgo/MLGOHeuristics.cpp
+++ b/src/runtime/CL/mlgo/MLGOHeuristics.cpp
@@ -24,6 +24,7 @@
#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
#include "arm_compute/core/Log.h"
+
#include "src/runtime/CL/mlgo/MLGOParser.h"
#include "src/runtime/CL/mlgo/Utils.h"
@@ -39,19 +40,19 @@ bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
}
bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs)
{
- return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs,
- rhs.export_cl_image);
+ return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) ==
+ std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
}
bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs)
{
- return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0,
- rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
+ return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs,
+ lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0, rhs.interleave_lhs,
+ rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image);
}
constexpr size_t MLGOHeuristics::_max_num_trees;
-MLGOHeuristics::MLGOHeuristics()
- : _indices{}, _trees{}, _tree_valid{}, _valid{ false }
+MLGOHeuristics::MLGOHeuristics() : _indices{}, _trees{}, _tree_valid{}, _valid{false}
{
}
@@ -59,71 +60,74 @@ std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(const Query &query) co
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str());
const auto invalid = GEMMType::RESHAPED;
- if(!_valid)
+ if (!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
- return { false, invalid };
+ return {false, invalid};
}
auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type);
- GEMMShape shape_query{ query.m, query.n, query.k, query.b };
- if(_trees.find(index) == _trees.end())
+ GEMMShape shape_query{query.m, query.n, query.k, query.b};
+ if (_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
- return { false, invalid };
+ return {false, invalid};
}
return _trees.at(index).query<GEMMType>(shape_query);
}
std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(const Query &query) const
{
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str());
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.",
+ to_string(query).c_str());
const auto invalid = GEMMConfigNative{};
- if(!_valid)
+ if (!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
- return { false, invalid };
+ return {false, invalid};
}
auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type);
- GEMMShape shape_query{ query.m, query.n, query.k, query.b };
- if(_trees.find(index) == _trees.end())
+ GEMMShape shape_query{query.m, query.n, query.k, query.b};
+ if (_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
- return { false, invalid };
+ return {false, invalid};
}
return _trees.at(index).query<GEMMConfigNative>(shape_query);
}
std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const
{
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str());
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.",
+ to_string(query).c_str());
const auto invalid = GEMMConfigReshapedOnlyRHS{};
- if(!_valid)
+ if (!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
- return { false, invalid };
+ return {false, invalid};
}
auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type);
- GEMMShape shape_query{ query.m, query.n, query.k, query.b };
- if(_trees.find(index) == _trees.end())
+ GEMMShape shape_query{query.m, query.n, query.k, query.b};
+ if (_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
- return { false, invalid };
+ return {false, invalid};
}
return _trees.at(index).query<GEMMConfigReshapedOnlyRHS>(shape_query);
}
std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const
{
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str());
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.",
+ to_string(query).c_str());
const auto invalid = GEMMConfigReshaped{};
- if(!_valid)
+ if (!_valid)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
- return { false, invalid };
+ return {false, invalid};
}
auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type);
- GEMMShape shape_query{ query.m, query.n, query.k, query.b };
- if(_trees.find(index) == _trees.end())
+ GEMMShape shape_query{query.m, query.n, query.k, query.b};
+ if (_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
- return { false, invalid };
+ return {false, invalid};
}
return _trees.at(index).query<GEMMConfigReshaped>(shape_query);
}
@@ -131,14 +135,14 @@ std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(c
bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
{
bool status;
- HeuristicTree *tree{ nullptr };
+ HeuristicTree *tree{nullptr};
std::tie(status, tree) = get_heuristic_tree(id);
- if(!status)
+ if (!status)
{
return status;
}
status = tree->check();
- if(!status)
+ if (!status)
{
return status;
}
@@ -149,14 +153,12 @@ bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
bool MLGOHeuristics::check_all() const
{
// Tree validities are already checked and cached.
- bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v)
- {
- return !v.second;
- })
- == _tree_valid.end();
- if(!all_trees_are_checked)
+ bool all_trees_are_checked =
+ std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v) { return !v.second; }) == _tree_valid.end();
+ if (!all_trees_are_checked)
{
- ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each "
+ "tree is completed. This could also indicate there are no trees in the dotmlgo");
return false;
}
@@ -167,14 +169,14 @@ bool MLGOHeuristics::check_all() const
std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
{
- if(_indices.find(id) == _indices.end())
+ if (_indices.find(id) == _indices.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
return std::make_pair(false, nullptr);
}
const auto index = _indices[id];
- if(_trees.find(index) == _trees.end())
+ if (_trees.find(index) == _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
return std::make_pair(false, nullptr);
@@ -186,7 +188,7 @@ std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTre
bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
{
- if(_indices.size() >= _max_num_trees)
+ if (_indices.size() >= _max_num_trees)
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
return false;
@@ -194,7 +196,7 @@ bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
// PRE: correctness of t is guaranteed by the tree construction process
// Ensure unique id
const auto id = t.id();
- if(_indices.find(id) != _indices.end())
+ if (_indices.find(id) != _indices.end())
{
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id);
return false;
@@ -202,7 +204,7 @@ bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
// Ensure unique index
const auto index = t.index();
- if(_trees.find(index) != _trees.end())
+ if (_trees.find(index) != _trees.end())
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
return false;
@@ -219,9 +221,10 @@ bool MLGOHeuristics::reload_from_file(const std::string &filename)
std::ifstream fs;
fs.exceptions(std::ifstream::badbit);
fs.open(filename, std::ios::in);
- if(!fs.is_open())
+ if (!fs.is_open())
{
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead",
+ filename.c_str());
return _valid = false;
}
return reload_from_stream(fs);
@@ -230,7 +233,7 @@ bool MLGOHeuristics::reload_from_file(const std::string &filename)
bool MLGOHeuristics::reload_from_stream(std::istream &in)
{
auto parsed = parser::parse_mlgo(in);
- if(!parsed.first)
+ if (!parsed.first)
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
return _valid = false;
@@ -241,4 +244,4 @@ bool MLGOHeuristics::reload_from_stream(std::istream &in)
}
} // namespace mlgo
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute
diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.h b/src/runtime/CL/mlgo/MLGOHeuristics.h
index aa21225959..6a491c5503 100644
--- a/src/runtime/CL/mlgo/MLGOHeuristics.h
+++ b/src/runtime/CL/mlgo/MLGOHeuristics.h
@@ -135,16 +135,16 @@ public:
bool check_all() const;
private:
- static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/
+ static constexpr size_t _max_num_trees{100}; /**< Max number of trees that can be added*/
private:
// There exists a one-to-one mappipng between TreeID and Index, either can be used to identify a @ref HeuristicTree
std::map<HeuristicTree::TreeID, HeuristicTree::Index> _indices; /**< A mapping from TreeID to Index */
std::map<HeuristicTree::Index, HeuristicTree> _trees; /**< A mapping from Index to HeuristicTree */
std::map<HeuristicTree::TreeID, bool> _tree_valid; /**< Result cache of the tree validity checks */
- bool _valid; /**< Overall validity */
+ bool _valid; /**< Overall validity */
};
} // namespace mlgo
} // namespace arm_compute
-#endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H \ No newline at end of file
+#endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H
diff --git a/src/runtime/CL/mlgo/MLGOParser.cpp b/src/runtime/CL/mlgo/MLGOParser.cpp
index 625739e450..893daf2ed9 100644
--- a/src/runtime/CL/mlgo/MLGOParser.cpp
+++ b/src/runtime/CL/mlgo/MLGOParser.cpp
@@ -22,19 +22,21 @@
* SOFTWARE.
*/
#include "src/runtime/CL/mlgo/MLGOParser.h"
+
#include "arm_compute/core/Log.h"
+
#include "src/runtime/CL/mlgo/Utils.h"
#include <sstream>
#define CHECK(parser_expr, valid_var) \
(parser_expr); \
- if(!valid_var) \
+ if (!valid_var) \
return;
#define CHECK_DEFAULT(parser_expr, valid_var, default_val) \
(parser_expr); \
- if(!valid_var) \
+ if (!valid_var) \
return default_val;
#ifdef ARM_COMPUTE_LOGGING_ENABLED
@@ -53,8 +55,7 @@
valid_var = false; \
return default_val;
-#define LOG_TOKEN_POS(tokens, pos_var) \
- const auto pos_var = tokens.current_pos();
+#define LOG_TOKEN_POS(tokens, pos_var) const auto pos_var = tokens.current_pos();
#else // ARM_COMPUTE_LOGGING_ENABLED
@@ -73,19 +74,12 @@ namespace
{
void ltrim(std::string &str)
{
- str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char ch)
- {
- return !std::isspace(ch);
- }));
+ str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char ch) { return !std::isspace(ch); }));
}
void rtrim(std::string &str)
{
- str.erase(std::find_if(str.rbegin(), str.rend(), [](char ch)
- {
- return !std::isspace(ch);
- }).base(),
- str.end());
+ str.erase(std::find_if(str.rbegin(), str.rend(), [](char ch) { return !std::isspace(ch); }).base(), str.end());
}
void trim(std::string &str)
@@ -109,7 +103,7 @@ enum class ComparatorType
};
TokenStream::TokenStream(std::istream &s, const std::string &delims)
- : _delims{ delims }, _istream{ s }, _tokens{}, _lookahead_pos{}
+ : _delims{delims}, _istream{s}, _tokens{}, _lookahead_pos{}
{
read();
}
@@ -125,7 +119,7 @@ Token TokenStream::take()
ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty");
Token t = _tokens.front();
_tokens.pop_front();
- if(_tokens.empty())
+ if (_tokens.empty())
{
read();
}
@@ -136,7 +130,7 @@ Token TokenStream::peek(size_t i)
ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty");
ARM_COMPUTE_ERROR_ON_MSG(i >= max_look_ahead, "TokenStream: Exceeding max look ahead");
// NOTE: If i exceeds the stream (_istream.eof()), read() automatically appends a End token at the end
- while(_istream && _tokens.size() <= i)
+ while (_istream && _tokens.size() <= i)
{
read();
}
@@ -146,7 +140,7 @@ Token TokenStream::peek(size_t i)
void advance(CharPosition &pos, char ch)
{
- if(ch == '\n')
+ if (ch == '\n')
{
pos.ln += 1;
pos.col = 0;
@@ -167,17 +161,16 @@ void TokenStream::read()
do
{
// Reached eof
- if(!_istream.get(ch))
+ if (!_istream.get(ch))
{
- if(!reached_end())
+ if (!reached_end())
{
_tokens.emplace_back(TokenType::End, "", _lookahead_pos);
}
return;
}
advance(_lookahead_pos, ch);
- }
- while(std::isspace(ch) || is_delim(ch));
+ } while (std::isspace(ch) || is_delim(ch));
// Read chars until we hit a delim or eof
auto orig_pos = _lookahead_pos;
auto tok = recognize_tok(ch);
@@ -190,41 +183,41 @@ void TokenStream::read()
Token TokenStream::recognize_tok(char ch)
{
- if(ch == '[')
+ if (ch == '[')
{
- return Token{ TokenType::L_List, "", _lookahead_pos };
+ return Token{TokenType::L_List, "", _lookahead_pos};
}
- else if(ch == ']')
+ else if (ch == ']')
{
- return Token{ TokenType::R_List, "", _lookahead_pos };
+ return Token{TokenType::R_List, "", _lookahead_pos};
}
- else if(ch == '.')
+ else if (ch == '.')
{
- return float_after_dp_st(std::string{ ch });
+ return float_after_dp_st(std::string{ch});
}
- else if(std::isdigit(ch))
+ else if (std::isdigit(ch))
{
- return num_st(std::string{ ch });
+ return num_st(std::string{ch});
}
else
{
- return text_st(std::string{ ch });
+ return text_st(std::string{ch});
}
}
Token TokenStream::num_st(std::string value)
{
char ch{};
- while(_istream.get(ch))
+ while (_istream.get(ch))
{
advance(_lookahead_pos, ch);
- if(ch == '.')
+ if (ch == '.')
{
return float_after_dp_st(value + ch);
}
- else if(!std::isdigit(ch))
+ else if (!std::isdigit(ch))
{
- if(!is_delim(ch) && !std::isspace(ch))
+ if (!is_delim(ch) && !std::isspace(ch))
{
rewind(_lookahead_pos);
_istream.unget();
@@ -233,18 +226,18 @@ Token TokenStream::num_st(std::string value)
}
value += ch;
}
- return Token{ TokenType::Int, value, _lookahead_pos };
+ return Token{TokenType::Int, value, _lookahead_pos};
}
Token TokenStream::float_after_dp_st(std::string value)
{
char ch{};
- while(_istream.get(ch))
+ while (_istream.get(ch))
{
advance(_lookahead_pos, ch);
- if(!std::isdigit(ch))
+ if (!std::isdigit(ch))
{
- if(!is_delim(ch) && !std::isspace(ch))
+ if (!is_delim(ch) && !std::isspace(ch))
{
rewind(_lookahead_pos);
_istream.unget();
@@ -253,20 +246,20 @@ Token TokenStream::float_after_dp_st(std::string value)
}
value += ch;
}
- return Token{ TokenType::Float, value, _lookahead_pos };
+ return Token{TokenType::Float, value, _lookahead_pos};
}
Token TokenStream::text_st(std::string value)
{
char ch{};
- while(_istream.get(ch))
+ while (_istream.get(ch))
{
advance(_lookahead_pos, ch);
- if(is_delim(ch))
+ if (is_delim(ch))
{
break;
}
- if(ch == '[' || ch == ']')
+ if (ch == '[' || ch == ']')
{
rewind(_lookahead_pos);
_istream.unget();
@@ -274,7 +267,7 @@ Token TokenStream::text_st(std::string value)
}
value += ch;
}
- return Token{ TokenType::Text, value, _lookahead_pos };
+ return Token{TokenType::Text, value, _lookahead_pos};
}
bool TokenStream::reached_end() const
@@ -291,7 +284,7 @@ void end(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::End)
+ if (tok.type != TokenType::End)
{
FAIL_WITH_MSG(valid, pos, "Unexpected token at the end of stream");
}
@@ -301,7 +294,7 @@ bool bool_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Int)
+ if (tok.type != TokenType::Int)
{
FAIL_WITH_MSG_DEFAULT(valid, false, pos, "Expect bool or int token");
}
@@ -314,7 +307,7 @@ int int_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Int)
+ if (tok.type != TokenType::Int)
{
FAIL_WITH_MSG_DEFAULT(valid, -1, pos, "Expect int token");
}
@@ -327,7 +320,7 @@ unsigned int uint_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
int val = CHECK_DEFAULT(int_val(in, valid), valid, 0);
- if(val < 0)
+ if (val < 0)
{
FAIL_WITH_MSG_DEFAULT(valid, 0, pos, "Expect unsigned int token");
}
@@ -338,7 +331,7 @@ float float_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Float)
+ if (tok.type != TokenType::Float)
{
FAIL_WITH_MSG_DEFAULT(valid, 0.f, pos, "Expect float token");
}
@@ -351,7 +344,7 @@ std::string text_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Text || tok.value.empty())
+ if (tok.type != TokenType::Text || tok.value.empty())
{
FAIL_WITH_MSG_DEFAULT(valid, "", pos, "Expect a non-empty text token");
}
@@ -361,9 +354,9 @@ std::string text_val(TokenStream &in, bool &valid)
bool accept_text(TokenStream &in, const std::string &c_str, bool take = true)
{
auto tok = in.peek();
- if(tok.type == TokenType::Text && tok.value == c_str)
+ if (tok.type == TokenType::Text && tok.value == c_str)
{
- if(take)
+ if (take)
{
in.take();
}
@@ -375,7 +368,7 @@ bool accept_text(TokenStream &in, const std::string &c_str, bool take = true)
void expect_text(TokenStream &in, const std::string &str, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(!accept_text(in, str))
+ if (!accept_text(in, str))
{
FAIL_WITH_MSG(valid, pos, std::string("Expect text token: ") + str);
}
@@ -384,7 +377,7 @@ void expect_text(TokenStream &in, const std::string &str, bool &valid)
bool accept_l_list(TokenStream &in)
{
auto tok = in.peek();
- if(tok.type == TokenType::L_List)
+ if (tok.type == TokenType::L_List)
{
in.take();
return true;
@@ -395,7 +388,7 @@ bool accept_l_list(TokenStream &in)
void expect_l_list(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(!accept_l_list(in))
+ if (!accept_l_list(in))
{
FAIL_WITH_MSG(valid, pos, "Expect '['");
}
@@ -404,7 +397,7 @@ void expect_l_list(TokenStream &in, bool &valid)
bool accept_r_list(TokenStream &in)
{
auto tok = in.peek();
- if(tok.type == TokenType::R_List)
+ if (tok.type == TokenType::R_List)
{
in.take();
return true;
@@ -415,7 +408,7 @@ bool accept_r_list(TokenStream &in)
void expect_r_list(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(!accept_r_list(in))
+ if (!accept_r_list(in))
{
FAIL_WITH_MSG(valid, pos, "Expect ']'");
}
@@ -424,23 +417,23 @@ void expect_r_list(TokenStream &in, bool &valid)
ConditionalOp conditional_op(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "<="))
+ if (accept_text(in, "<="))
{
return ConditionalOp::LE;
}
- else if(accept_text(in, ">="))
+ else if (accept_text(in, ">="))
{
return ConditionalOp::GE;
}
- else if(accept_text(in, "=="))
+ else if (accept_text(in, "=="))
{
return ConditionalOp::EQ;
}
- else if(accept_text(in, "<"))
+ else if (accept_text(in, "<"))
{
return ConditionalOp::LT;
}
- else if(accept_text(in, ">"))
+ else if (accept_text(in, ">"))
{
return ConditionalOp::GT;
}
@@ -464,11 +457,11 @@ void ip_type(TokenStream &in, bool &valid)
{
CHECK(expect_text(in, "ip-type", valid), valid);
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "gpu"))
+ if (accept_text(in, "gpu"))
{
;
}
- else if(accept_text(in, "cpu"))
+ else if (accept_text(in, "cpu"))
{
;
}
@@ -489,15 +482,15 @@ void header(TokenStream &in, bool &valid)
DataType data_type(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "f16"))
+ if (accept_text(in, "f16"))
{
return DataType::F16;
}
- else if(accept_text(in, "f32"))
+ else if (accept_text(in, "f32"))
{
return DataType::F32;
}
- else if(accept_text(in, "qasymm8"))
+ else if (accept_text(in, "qasymm8"))
{
return DataType::QASYMM8;
}
@@ -510,15 +503,15 @@ DataType data_type(TokenStream &in, bool &valid)
ComparatorType comparator_type(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "var"))
+ if (accept_text(in, "var"))
{
return ComparatorType::Var;
}
- else if(accept_text(in, "num"))
+ else if (accept_text(in, "num"))
{
return ComparatorType::Num;
}
- else if(accept_text(in, "enum"))
+ else if (accept_text(in, "enum"))
{
return ComparatorType::Enum;
}
@@ -531,19 +524,19 @@ ComparatorType comparator_type(TokenStream &in, bool &valid)
HeuristicType heuristic_type(TokenStream &in, bool &valid, bool take = true)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "gemm-type", take))
+ if (accept_text(in, "gemm-type", take))
{
return HeuristicType::GEMM_Type;
}
- else if(accept_text(in, "gemm-config-native", take))
+ else if (accept_text(in, "gemm-config-native", take))
{
return HeuristicType::GEMM_Config_Native;
}
- else if(accept_text(in, "gemm-config-reshaped-only-rhs", take))
+ else if (accept_text(in, "gemm-config-reshaped-only-rhs", take))
{
return HeuristicType::GEMM_Config_Reshaped_Only_RHS;
}
- else if(accept_text(in, "gemm-config-reshaped", take))
+ else if (accept_text(in, "gemm-config-reshaped", take))
{
return HeuristicType::GEMM_Config_Reshaped;
}
@@ -557,7 +550,7 @@ void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &val
{
LOG_TOKEN_POS(in, pos);
auto ht = CHECK(heuristic_type(in, valid, false), valid);
- if(ht != expected_ht)
+ if (ht != expected_ht)
{
FAIL_WITH_MSG(valid, pos, "Unexpected heuristic type");
}
@@ -567,15 +560,15 @@ void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &val
GEMMType gemm_type(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "native"))
+ if (accept_text(in, "native"))
{
return GEMMType::NATIVE;
}
- else if(accept_text(in, "reshaped-only-rhs"))
+ else if (accept_text(in, "reshaped-only-rhs"))
{
return GEMMType::RESHAPED_ONLY_RHS;
}
- else if(accept_text(in, "reshaped"))
+ else if (accept_text(in, "reshaped"))
{
return GEMMType::RESHAPED;
}
@@ -593,7 +586,7 @@ GEMMConfigNative gemm_config_native(TokenStream &in, bool &valid)
const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
- return GEMMConfigNative{ m0, n0, k0 };
+ return GEMMConfigNative{m0, n0, k0};
}
GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &valid)
@@ -608,7 +601,7 @@ GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &v
const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
- return GEMMConfigReshapedOnlyRHS{ m0, n0, k0, h0, ir, tr, ex };
+ return GEMMConfigReshapedOnlyRHS{m0, n0, k0, h0, ir, tr, ex};
}
GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid)
@@ -625,17 +618,17 @@ GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid)
const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
- return GEMMConfigReshaped{ m0, n0, k0, v0, h0, il, ir, tr, ex };
+ return GEMMConfigReshaped{m0, n0, k0, v0, h0, il, ir, tr, ex};
}
void gpu_priority(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "best-performance"))
+ if (accept_text(in, "best-performance"))
{
;
}
- else if(accept_text(in, "best-memory-usage"))
+ else if (accept_text(in, "best-memory-usage"))
{
;
}
@@ -648,11 +641,11 @@ void gpu_priority(TokenStream &in, bool &valid)
void gpu_behavior(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "static"))
+ if (accept_text(in, "static"))
{
;
}
- else if(accept_text(in, "dynamic"))
+ else if (accept_text(in, "dynamic"))
{
;
}
@@ -665,7 +658,7 @@ void gpu_behavior(TokenStream &in, bool &valid)
void free_vars(TokenStream &in, bool &valid)
{
CHECK(expect_l_list(in, valid), valid);
- while(!accept_r_list(in))
+ while (!accept_r_list(in))
{
CHECK(text_val(in, valid), valid);
}
@@ -688,7 +681,7 @@ void heuristics_table_entry(TokenStream &in, MLGOHeuristics &h, bool &valid)
void heuristics_table(TokenStream &in, MLGOHeuristics &h, bool &valid)
{
CHECK(expect_text(in, "<heuristics-table>", valid), valid);
- while(!accept_text(in, "</heuristics-table>"))
+ while (!accept_text(in, "</heuristics-table>"))
{
CHECK(heuristics_table_entry(in, h, valid), valid);
}
@@ -705,11 +698,12 @@ Condition condition(TokenStream &in, bool &valid)
const auto c_o = CHECK_DEFAULT(conditional_op(in, valid), valid, invalid_val);
const auto r_t = CHECK_DEFAULT(comparator_type(in, valid), valid, invalid_val);
const auto r_v = CHECK_DEFAULT(float_val(in, valid), valid, invalid_val);
- if(l_t != ComparatorType::Var || r_t != ComparatorType::Num)
+ if (l_t != ComparatorType::Var || r_t != ComparatorType::Num)
{
- FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos, "Only accept LHS type to be Var (string) and RHS type to be Num (float)");
+ FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos,
+ "Only accept LHS type to be Var (string) and RHS type to be Num (float)");
}
- return Condition{ l_v, c_o, r_v };
+ return Condition{l_v, c_o, r_v};
}
void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
@@ -717,13 +711,13 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
CHECK(expect_text(in, "<heuristic", valid), valid);
const auto tree_id = CHECK(uint_val(in, valid), valid);
CHECK(expect_text(in, ">", valid), valid);
- HeuristicTree *t = nullptr;
- std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid);
+ HeuristicTree *t = nullptr;
+ std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid);
const HeuristicType t_heuristic_type = std::get<0>(t->index());
- while(!accept_text(in, "</heuristic>"))
+ while (!accept_text(in, "</heuristic>"))
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "b"))
+ if (accept_text(in, "b"))
{
// Branch node
const auto id = CHECK(uint_val(in, valid), valid);
@@ -732,7 +726,7 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
const auto f_id = CHECK(uint_val(in, valid), valid);
valid = CHECK(t->add_branch(id, cond, t_id, f_id), valid);
}
- else if(accept_text(in, "l"))
+ else if (accept_text(in, "l"))
{
// Leaf node
const auto id = CHECK(uint_val(in, valid), valid);
@@ -740,7 +734,7 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
// heuristic table). For now it remains as a step for validation.
LOG_TOKEN_POS(in, pos);
CHECK(expect_heuristic_type(in, t_heuristic_type, valid), valid);
- switch(t_heuristic_type)
+ switch (t_heuristic_type)
{
case HeuristicType::GEMM_Type:
{
@@ -786,7 +780,7 @@ MLGOHeuristics mlgo(TokenStream &in, bool &valid)
MLGOHeuristics h;
CHECK_DEFAULT(header(in, valid), valid, h);
CHECK_DEFAULT(heuristics_table(in, h, valid), valid, h);
- while(accept_text(in, "<heuristic", false))
+ while (accept_text(in, "<heuristic", false))
{
CHECK_DEFAULT(heuristic_tree(in, h, valid), valid, h);
}
@@ -809,4 +803,4 @@ std::pair<bool, MLGOHeuristics> parse_mlgo(std::istream &in)
#undef CHECK
#undef CHECK_DEFAULT
#undef FAIL_WITH_MSG
-#undef FAIL_WITH_MSG_DEFAULT \ No newline at end of file
+#undef FAIL_WITH_MSG_DEFAULT
diff --git a/src/runtime/CL/mlgo/MLGOParser.h b/src/runtime/CL/mlgo/MLGOParser.h
index 49d8b9c644..cffce8d6a1 100644
--- a/src/runtime/CL/mlgo/MLGOParser.h
+++ b/src/runtime/CL/mlgo/MLGOParser.h
@@ -98,15 +98,14 @@ struct CharPosition
return ln == other.ln && col == other.col;
}
- size_t ln{ 0 };
- size_t col{ 0 };
+ size_t ln{0};
+ size_t col{0};
};
/** Token */
struct Token
{
- Token(TokenType t, std::string v, CharPosition pos)
- : type{ t }, value{ v }, pos{ pos }
+ Token(TokenType t, std::string v, CharPosition pos) : type{t}, value{v}, pos{pos}
{
}
@@ -196,4 +195,4 @@ std::pair<bool, MLGOHeuristics> parse_mlgo(std::istream &in);
} // namespace parser
} // namespace mlgo
} // namespace arm_compute
-#endif //SRC_RUNTIME_CL_MLGO_MLGO_PARSER_H \ No newline at end of file
+#endif //SRC_RUNTIME_CL_MLGO_MLGO_PARSER_H
diff --git a/src/runtime/CL/mlgo/Utils.cpp b/src/runtime/CL/mlgo/Utils.cpp
index 81d418c28e..c7e0100b3c 100644
--- a/src/runtime/CL/mlgo/Utils.cpp
+++ b/src/runtime/CL/mlgo/Utils.cpp
@@ -43,40 +43,38 @@ inline std::string to_str(const T &val)
std::ostream &operator<<(std::ostream &os, const GEMMConfigNative &config)
{
return os << "Native:{"
- << "m0: " << config.m0 << ", "
- << "n0: " << config.n0 << ", "
- << "k0: " << config.k0 << ", "
- << "}";
+ << "m0: " << config.m0 << ", "
+ << "n0: " << config.n0 << ", "
+ << "k0: " << config.k0 << ", "
+ << "}";
}
std::ostream &operator<<(std::ostream &os, const GEMMConfigReshapedOnlyRHS &config)
{
return os << "ReshapedOnlyRHS:{"
- << "m0: " << config.m0 << ", "
- << "n0: " << config.n0 << ", "
- << "k0: " << config.k0 << ", "
- << "h0: " << config.h0 << ", "
- << "interleave_rhs: " << config.interleave_rhs << ", "
- << "transpose_rhs: " << config.transpose_rhs << ", "
- << "export_cl_image: " << config.export_cl_image
- << "}";
+ << "m0: " << config.m0 << ", "
+ << "n0: " << config.n0 << ", "
+ << "k0: " << config.k0 << ", "
+ << "h0: " << config.h0 << ", "
+ << "interleave_rhs: " << config.interleave_rhs << ", "
+ << "transpose_rhs: " << config.transpose_rhs << ", "
+ << "export_cl_image: " << config.export_cl_image << "}";
}
std::ostream &operator<<(std::ostream &os, const GEMMConfigReshaped &config)
{
return os << "Reshaped:{"
- << "m0: " << config.m0 << ", "
- << "n0: " << config.n0 << ", "
- << "k0: " << config.k0 << ", "
- << "v0: " << config.v0 << ", "
- << "h0: " << config.h0 << ", "
- << "interleave_lhs: " << config.interleave_lhs << ", "
- << "interleave_rhs: " << config.interleave_rhs << ", "
- << "transpose_rhs: " << config.transpose_rhs << ", "
- << "export_cl_image: " << config.export_cl_image
- << "}";
+ << "m0: " << config.m0 << ", "
+ << "n0: " << config.n0 << ", "
+ << "k0: " << config.k0 << ", "
+ << "v0: " << config.v0 << ", "
+ << "h0: " << config.h0 << ", "
+ << "interleave_lhs: " << config.interleave_lhs << ", "
+ << "interleave_rhs: " << config.interleave_rhs << ", "
+ << "transpose_rhs: " << config.transpose_rhs << ", "
+ << "export_cl_image: " << config.export_cl_image << "}";
}
std::ostream &operator<<(std::ostream &os, HeuristicType ht)
{
- switch(ht)
+ switch (ht)
{
case HeuristicType::GEMM_Type:
{
@@ -103,7 +101,7 @@ std::ostream &operator<<(std::ostream &os, HeuristicType ht)
}
std::ostream &operator<<(std::ostream &os, DataType dt)
{
- switch(dt)
+ switch (dt)
{
case DataType::F32:
{
@@ -184,4 +182,4 @@ std::ostream &operator<<(std::ostream &os, const CharPosition &pos)
} // namespace mlgo
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute
diff --git a/src/runtime/CL/mlgo/Utils.h b/src/runtime/CL/mlgo/Utils.h
index c634a887e9..73b537f476 100644
--- a/src/runtime/CL/mlgo/Utils.h
+++ b/src/runtime/CL/mlgo/Utils.h
@@ -43,10 +43,10 @@ std::ostream &operator<<(std::ostream &os, HeuristicType ht);
std::ostream &operator<<(std::ostream &os, DataType dt);
std::ostream &operator<<(std::ostream &os, const HeuristicTree::Index &index);
std::ostream &operator<<(std::ostream &os, const Query &query);
-std::string to_string(const GEMMConfigNative &config);
-std::string to_string(const GEMMConfigReshapedOnlyRHS &config);
-std::string to_string(const GEMMConfigReshaped &config);
-std::string to_string(const Query &query);
+std::string to_string(const GEMMConfigNative &config);
+std::string to_string(const GEMMConfigReshapedOnlyRHS &config);
+std::string to_string(const GEMMConfigReshaped &config);
+std::string to_string(const Query &query);
namespace parser
{
std::ostream &operator<<(std::ostream &os, const CharPosition &pos);
@@ -54,4 +54,4 @@ std::ostream &operator<<(std::ostream &os, const CharPosition &pos);
} // namespace mlgo
} // namespace arm_compute
-#endif //SRC_RUNTIME_CL_MLGO_UTILS_H \ No newline at end of file
+#endif //SRC_RUNTIME_CL_MLGO_UTILS_H