From afd38f0c617d6f89b2b4532c6c44f116617e2b6f Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 27 Sep 2023 17:46:17 +0100 Subject: Apply clang-format on repository Code is formatted as per a revised clang format configuration file(not part of this delivery). Version 14.0.6 is used. Exclusion List: - files with .cl extension - files that are not strictly C/C++ (e.g. Android.bp, Sconscript ...) And the following directories - compute_kernel_writer/validation/ - tests/ - include/ - src/core/NEON/kernels/convolution/ - src/core/NEON/kernels/arm_gemm/ - src/core/NEON/kernels/arm_conv/ - data/ There will be a follow up for formatting of .cl files and the files under tests/ and compute_kernel_writer/validation/. Signed-off-by: Felix Thomasmathibalan Change-Id: Ib7eb1fcf4e7537b9feaefcfc15098a804a3fde0a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10391 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gunes Bayir --- src/runtime/CL/mlgo/Common.h | 40 +++---- src/runtime/CL/mlgo/HeuristicTree.cpp | 89 +++++++--------- src/runtime/CL/mlgo/HeuristicTree.h | 24 ++--- src/runtime/CL/mlgo/MLGOHeuristics.cpp | 99 ++++++++--------- src/runtime/CL/mlgo/MLGOHeuristics.h | 6 +- src/runtime/CL/mlgo/MLGOParser.cpp | 188 ++++++++++++++++----------------- src/runtime/CL/mlgo/MLGOParser.h | 9 +- src/runtime/CL/mlgo/Utils.cpp | 48 ++++----- src/runtime/CL/mlgo/Utils.h | 10 +- 9 files changed, 250 insertions(+), 263 deletions(-) (limited to 'src/runtime/CL/mlgo') diff --git a/src/runtime/CL/mlgo/Common.h b/src/runtime/CL/mlgo/Common.h index c451bd9062..08a7ee8c18 100644 --- a/src/runtime/CL/mlgo/Common.h +++ b/src/runtime/CL/mlgo/Common.h @@ -45,37 +45,37 @@ using GEMMType = CLGEMMKernelType; /** GEMM Configuration for Native kernel */ struct GEMMConfigNative { - unsigned int m0{ 1 }; /**< Number of rows processed by the matrix multiplication */ - unsigned int n0{ 1 }; /**< Number of columns processed by the matrix multiplication */ - unsigned int k0{ 1 }; /**< Number of partial accumulations performed by the matrix multiplication */ + unsigned int m0{1}; /**< Number of rows processed by the matrix multiplication */ + unsigned int n0{1}; /**< Number of columns processed by the matrix multiplication */ + unsigned int k0{1}; /**< Number of partial accumulations performed by the matrix multiplication */ }; /** GEMM Configuration for Reshaped Only RHS kernel */ struct GEMMConfigReshapedOnlyRHS { - unsigned int m0{ 1 }; /**< Number of rows processed by the matrix multiplication */ - unsigned int n0{ 1 }; /**< Number of columns processed by the matrix multiplication */ - unsigned int k0{ 1 }; /**< Number of partial accumulations performed by the matrix multiplication */ - unsigned int h0{ 1 }; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */ - bool interleave_rhs{ false }; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */ - bool transpose_rhs{ false }; /**< True if the (k0xn0) block has to be transposed before been stored */ - bool export_cl_image{ false }; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */ + unsigned int m0{1}; /**< Number of rows processed by the matrix multiplication */ + unsigned int n0{1}; /**< Number of columns processed by the matrix multiplication */ + unsigned int k0{1}; /**< Number of partial accumulations performed by the matrix multiplication */ + unsigned int h0{1}; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */ + bool interleave_rhs{false}; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */ + bool transpose_rhs{false}; /**< True if the (k0xn0) block has to be transposed before been stored */ + bool export_cl_image{false}; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */ }; /** GEMM Configuration for Reshaped kernel */ struct GEMMConfigReshaped { - unsigned int m0{ 1 }; /**< Number of rows processed by the matrix multiplication */ - unsigned int n0{ 1 }; /**< Number of columns processed by the matrix multiplication */ - unsigned int k0{ 1 }; /**< Number of partial accumulations performed by the matrix multiplication */ - unsigned int v0{ 1 }; /**< Number of vertical blocks of size (m0xk0) stored on the same output row */ - unsigned int h0{ 1 }; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */ - bool interleave_lhs{ false }; /**< True if the v0 (m0xk0) blocks have to be interleaved in the output row */ - bool interleave_rhs{ false }; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */ - bool transpose_rhs{ false }; /**< True if the (k0xn0) block has to be transposed before been stored */ - bool export_cl_image{ false }; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */ + unsigned int m0{1}; /**< Number of rows processed by the matrix multiplication */ + unsigned int n0{1}; /**< Number of columns processed by the matrix multiplication */ + unsigned int k0{1}; /**< Number of partial accumulations performed by the matrix multiplication */ + unsigned int v0{1}; /**< Number of vertical blocks of size (m0xk0) stored on the same output row */ + unsigned int h0{1}; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */ + bool interleave_lhs{false}; /**< True if the v0 (m0xk0) blocks have to be interleaved in the output row */ + bool interleave_rhs{false}; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */ + bool transpose_rhs{false}; /**< True if the (k0xn0) block has to be transposed before been stored */ + bool export_cl_image{false}; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */ }; } // namespace mlgo } // namespace arm_compute -#endif // SRC_RUNTIME_CL_MLGO_COMMON_H \ No newline at end of file +#endif // SRC_RUNTIME_CL_MLGO_COMMON_H diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp index 1c75cdc427..f7b706902b 100644 --- a/src/runtime/CL/mlgo/HeuristicTree.cpp +++ b/src/runtime/CL/mlgo/HeuristicTree.cpp @@ -22,6 +22,7 @@ * SOFTWARE. */ #include "src/runtime/CL/mlgo/HeuristicTree.h" + #include "arm_compute/core/Log.h" #include "support/Cast.h" @@ -40,27 +41,23 @@ bool evaluate(GEMMShape shape, Condition cond) // PRE: all features and ConditionalOps are valid constexpr float eps = 0.0001f; // Calculate all secondary features - std::vector> cond_values - { - { "m", static_cast(shape.m) }, - { "n", static_cast(shape.n) }, - { "k", static_cast(shape.k) }, - { "b", static_cast(shape.b) }, - { "r_mn", static_cast(shape.m) / shape.n }, - { "r_mk", static_cast(shape.m) / shape.k }, - { "r_nk", static_cast(shape.n) / shape.k }, - { "r_mnk", static_cast(shape.m) / (static_cast(shape.n) / shape.k) }, - { "workload", (static_cast(shape.m) * shape.n * shape.b) / 20.0 } - }; - auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(), - [&cond](decltype(*cond_values.begin()) it) - { - return it.first == cond.feature; - }); + std::vector> cond_values{ + {"m", static_cast(shape.m)}, + {"n", static_cast(shape.n)}, + {"k", static_cast(shape.k)}, + {"b", static_cast(shape.b)}, + {"r_mn", static_cast(shape.m) / shape.n}, + {"r_mk", static_cast(shape.m) / shape.k}, + {"r_nk", static_cast(shape.n) / shape.k}, + {"r_mnk", static_cast(shape.m) / (static_cast(shape.n) / shape.k)}, + {"workload", (static_cast(shape.m) * shape.n * shape.b) / 20.0}}; + auto cond_value_pair_it = + std::find_if(cond_values.begin(), cond_values.end(), + [&cond](decltype(*cond_values.begin()) it) { return it.first == cond.feature; }); ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end()); const float cond_value = cond_value_pair_it->second; - switch(cond.op) + switch (cond.op) { case ConditionalOp::LT: { @@ -92,13 +89,12 @@ constexpr size_t HeuristicTree::_max_num_nodes; constexpr size_t HeuristicTree::_max_query_depth; constexpr HeuristicTree::NodeID HeuristicTree::_root; -HeuristicTree::HeuristicTree() - : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32) +HeuristicTree::HeuristicTree() : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32) { } HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type) - : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{} + : _id{id}, _heuristic_type{h_type}, _ip_target{ip_target}, _data_type{data_type}, _tree{} { } @@ -108,16 +104,17 @@ std::pair HeuristicTree::query(GEMMShape shape) const // Root ID = 0; auto cur_node = _tree.at(_root).get(); size_t depth = 0; - while(cur_node->type() != NodeType::Leaf) + while (cur_node->type() != NodeType::Leaf) { - if(depth > _max_query_depth) + if (depth > _max_query_depth) { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", + _max_query_depth); return std::make_pair(false, T{}); } ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType"); auto br_node = utils::cast::polymorphic_downcast(cur_node); - if(evaluate(shape, br_node->condition)) + if (evaluate(shape, br_node->condition)) { cur_node = _tree.at(br_node->true_node).get(); } @@ -135,12 +132,12 @@ std::pair HeuristicTree::query(GEMMShape shape) const template bool HeuristicTree::add_leaf(NodeID id, T val) { - if(_tree.size() >= _max_num_nodes) + if (_tree.size() >= _max_num_nodes) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); return false; } - if(_tree.find(id) != _tree.end()) + if (_tree.find(id) != _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); return false; @@ -151,28 +148,23 @@ bool HeuristicTree::add_leaf(NodeID id, T val) bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node) { - if(_tree.size() >= _max_num_nodes) + if (_tree.size() >= _max_num_nodes) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); return false; } - const std::set supported_features = - { - "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload" - }; - const auto orig_feature = cond.feature; - std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c) - { - return std::tolower(c); - }); - if(supported_features.find(cond.feature) == supported_features.end()) + const std::set supported_features = {"m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"}; + const auto orig_feature = cond.feature; + std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), + [](char c) { return std::tolower(c); }); + if (supported_features.find(cond.feature) == supported_features.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str()); return false; } - if(_tree.find(id) != _tree.end()) + if (_tree.find(id) != _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); return false; @@ -184,32 +176,32 @@ bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID bool HeuristicTree::check_if_structurally_correct() const { std::set visited; - std::deque to_visit{ _root }; + std::deque to_visit{_root}; - while(!to_visit.empty()) + while (!to_visit.empty()) { auto id = to_visit.front(); to_visit.pop_front(); - if(_tree.find(id) == _tree.end()) + if (_tree.find(id) == _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id); return false; } auto not_seen_before = visited.insert(id); - if(!not_seen_before.second) + if (!not_seen_before.second) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops"); return false; } auto cur_node = _tree.at(id).get(); - if(cur_node->type() == NodeType::Branch) + if (cur_node->type() == NodeType::Branch) { auto br_node = utils::cast::polymorphic_downcast(cur_node); to_visit.push_back(br_node->true_node); to_visit.push_back(br_node->false_node); } } - if(visited.size() != _tree.size()) + if (visited.size() != _tree.size()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes"); return false; @@ -219,12 +211,12 @@ bool HeuristicTree::check_if_structurally_correct() const bool HeuristicTree::check() { - if(_tree.empty()) + if (_tree.empty()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered"); return false; } - if(_tree.find(_root) == _tree.end()) + if (_tree.find(_root) == _tree.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root); return false; @@ -237,7 +229,8 @@ template std::pair HeuristicTree::query(GEMMShape shap /** Explicit template instantiation @relates HeuristicTree */ template std::pair HeuristicTree::query(GEMMShape shape) const; /** Explicit template instantiation @relates HeuristicTree */ -template std::pair HeuristicTree::query(GEMMShape shape) const; +template std::pair +HeuristicTree::query(GEMMShape shape) const; /** Explicit template instantiation @relates HeuristicTree */ template std::pair HeuristicTree::query(GEMMShape shape) const; diff --git a/src/runtime/CL/mlgo/HeuristicTree.h b/src/runtime/CL/mlgo/HeuristicTree.h index d5c7de2215..a4f8c116b9 100644 --- a/src/runtime/CL/mlgo/HeuristicTree.h +++ b/src/runtime/CL/mlgo/HeuristicTree.h @@ -25,6 +25,7 @@ #define SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H #include "arm_compute/core/Types.h" + #include "src/runtime/CL/mlgo/Common.h" #include @@ -84,7 +85,7 @@ public: struct BranchNode : public Node { BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node) - : id{ id }, condition{ cond }, true_node{ t_node }, false_node{ f_node } + : id{id}, condition{cond}, true_node{t_node}, false_node{f_node} { } NodeType type() const override @@ -100,8 +101,7 @@ public: template struct LeafNode : public Node { - LeafNode(NodeID id, T val) - : id{ id }, value{ val } + LeafNode(NodeID id, T val) : id{id}, value{val} { } NodeType type() const override @@ -177,22 +177,22 @@ public: bool check(); private: - static constexpr size_t _max_query_depth{ 1000 }; // Maximum depth of query - static constexpr size_t _max_num_nodes{ 100000 }; // Maximum number of nodes contained by the tree - static constexpr NodeID _root{ 0 }; // Root tree ID + static constexpr size_t _max_query_depth{1000}; // Maximum depth of query + static constexpr size_t _max_num_nodes{100000}; // Maximum number of nodes contained by the tree + static constexpr NodeID _root{0}; // Root tree ID private: bool check_if_structurally_correct() const; private: - TreeID _id; /**< Heuristic tree ID */ - HeuristicType _heuristic_type; /**< Heuristic type */ - std::string _ip_target; /**< IP target associated with the tree */ - DataType _data_type; /**< Data type associated with the tree */ - std::map> _tree; /**< Tree representation */ + TreeID _id; /**< Heuristic tree ID */ + HeuristicType _heuristic_type; /**< Heuristic type */ + std::string _ip_target; /**< IP target associated with the tree */ + DataType _data_type; /**< Data type associated with the tree */ + std::map> _tree; /**< Tree representation */ }; } // namespace mlgo } // namespace arm_compute -#endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H \ No newline at end of file +#endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.cpp b/src/runtime/CL/mlgo/MLGOHeuristics.cpp index 80f3bb85e9..aed46cd80f 100644 --- a/src/runtime/CL/mlgo/MLGOHeuristics.cpp +++ b/src/runtime/CL/mlgo/MLGOHeuristics.cpp @@ -24,6 +24,7 @@ #include "src/runtime/CL/mlgo/MLGOHeuristics.h" #include "arm_compute/core/Log.h" + #include "src/runtime/CL/mlgo/MLGOParser.h" #include "src/runtime/CL/mlgo/Utils.h" @@ -39,19 +40,19 @@ bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs) } bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs) { - return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, - rhs.export_cl_image); + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == + std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); } bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs) { - return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0, - rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, + lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0, rhs.interleave_lhs, + rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); } constexpr size_t MLGOHeuristics::_max_num_trees; -MLGOHeuristics::MLGOHeuristics() - : _indices{}, _trees{}, _tree_valid{}, _valid{ false } +MLGOHeuristics::MLGOHeuristics() : _indices{}, _trees{}, _tree_valid{}, _valid{false} { } @@ -59,71 +60,74 @@ std::pair MLGOHeuristics::query_gemm_type(const Query &query) co { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm type. %s.", to_string(query).c_str()); const auto invalid = GEMMType::RESHAPED; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query(shape_query); } std::pair MLGOHeuristics::query_gemm_config_native(const Query &query) const { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", to_string(query).c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config native. %s.", + to_string(query).c_str()); const auto invalid = GEMMConfigNative{}; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query(shape_query); } std::pair MLGOHeuristics::query_gemm_config_reshaped_only_rhs(const Query &query) const { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", to_string(query).c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped only rhs. %s.", + to_string(query).c_str()); const auto invalid = GEMMConfigReshapedOnlyRHS{}; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query(shape_query); } std::pair MLGOHeuristics::query_gemm_config_reshaped(const Query &query) const { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", to_string(query).c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics querying gemm config reshaped. %s.", + to_string(query).c_str()); const auto invalid = GEMMConfigReshaped{}; - if(!_valid) + if (!_valid) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); - return { false, invalid }; + return {false, invalid}; } auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type); - GEMMShape shape_query{ query.m, query.n, query.k, query.b }; - if(_trees.find(index) == _trees.end()) + GEMMShape shape_query{query.m, query.n, query.k, query.b}; + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); - return { false, invalid }; + return {false, invalid}; } return _trees.at(index).query(shape_query); } @@ -131,14 +135,14 @@ std::pair MLGOHeuristics::query_gemm_config_reshaped(c bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id) { bool status; - HeuristicTree *tree{ nullptr }; + HeuristicTree *tree{nullptr}; std::tie(status, tree) = get_heuristic_tree(id); - if(!status) + if (!status) { return status; } status = tree->check(); - if(!status) + if (!status) { return status; } @@ -149,14 +153,12 @@ bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id) bool MLGOHeuristics::check_all() const { // Tree validities are already checked and cached. - bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v) - { - return !v.second; - }) - == _tree_valid.end(); - if(!all_trees_are_checked) + bool all_trees_are_checked = + std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v) { return !v.second; }) == _tree_valid.end(); + if (!all_trees_are_checked) { - ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo"); + ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each " + "tree is completed. This could also indicate there are no trees in the dotmlgo"); return false; } @@ -167,14 +169,14 @@ bool MLGOHeuristics::check_all() const std::pair MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id) { - if(_indices.find(id) == _indices.end()) + if (_indices.find(id) == _indices.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id); return std::make_pair(false, nullptr); } const auto index = _indices[id]; - if(_trees.find(index) == _trees.end()) + if (_trees.find(index) == _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); return std::make_pair(false, nullptr); @@ -186,7 +188,7 @@ std::pair MLGOHeuristics::get_heuristic_tree(HeuristicTre bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) { - if(_indices.size() >= _max_num_trees) + if (_indices.size() >= _max_num_trees) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees); return false; @@ -194,7 +196,7 @@ bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) // PRE: correctness of t is guaranteed by the tree construction process // Ensure unique id const auto id = t.id(); - if(_indices.find(id) != _indices.end()) + if (_indices.find(id) != _indices.end()) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id); return false; @@ -202,7 +204,7 @@ bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) // Ensure unique index const auto index = t.index(); - if(_trees.find(index) != _trees.end()) + if (_trees.find(index) != _trees.end()) { ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists"); return false; @@ -219,9 +221,10 @@ bool MLGOHeuristics::reload_from_file(const std::string &filename) std::ifstream fs; fs.exceptions(std::ifstream::badbit); fs.open(filename, std::ios::in); - if(!fs.is_open()) + if (!fs.is_open()) { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str()); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", + filename.c_str()); return _valid = false; } return reload_from_stream(fs); @@ -230,7 +233,7 @@ bool MLGOHeuristics::reload_from_file(const std::string &filename) bool MLGOHeuristics::reload_from_stream(std::istream &in) { auto parsed = parser::parse_mlgo(in); - if(!parsed.first) + if (!parsed.first) { ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead"); return _valid = false; @@ -241,4 +244,4 @@ bool MLGOHeuristics::reload_from_stream(std::istream &in) } } // namespace mlgo -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.h b/src/runtime/CL/mlgo/MLGOHeuristics.h index aa21225959..6a491c5503 100644 --- a/src/runtime/CL/mlgo/MLGOHeuristics.h +++ b/src/runtime/CL/mlgo/MLGOHeuristics.h @@ -135,16 +135,16 @@ public: bool check_all() const; private: - static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/ + static constexpr size_t _max_num_trees{100}; /**< Max number of trees that can be added*/ private: // There exists a one-to-one mappipng between TreeID and Index, either can be used to identify a @ref HeuristicTree std::map _indices; /**< A mapping from TreeID to Index */ std::map _trees; /**< A mapping from Index to HeuristicTree */ std::map _tree_valid; /**< Result cache of the tree validity checks */ - bool _valid; /**< Overall validity */ + bool _valid; /**< Overall validity */ }; } // namespace mlgo } // namespace arm_compute -#endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H \ No newline at end of file +#endif //SRC_RUNTIME_CL_MLGO_MLGO_HEURISTICS_H diff --git a/src/runtime/CL/mlgo/MLGOParser.cpp b/src/runtime/CL/mlgo/MLGOParser.cpp index 625739e450..893daf2ed9 100644 --- a/src/runtime/CL/mlgo/MLGOParser.cpp +++ b/src/runtime/CL/mlgo/MLGOParser.cpp @@ -22,19 +22,21 @@ * SOFTWARE. */ #include "src/runtime/CL/mlgo/MLGOParser.h" + #include "arm_compute/core/Log.h" + #include "src/runtime/CL/mlgo/Utils.h" #include #define CHECK(parser_expr, valid_var) \ (parser_expr); \ - if(!valid_var) \ + if (!valid_var) \ return; #define CHECK_DEFAULT(parser_expr, valid_var, default_val) \ (parser_expr); \ - if(!valid_var) \ + if (!valid_var) \ return default_val; #ifdef ARM_COMPUTE_LOGGING_ENABLED @@ -53,8 +55,7 @@ valid_var = false; \ return default_val; -#define LOG_TOKEN_POS(tokens, pos_var) \ - const auto pos_var = tokens.current_pos(); +#define LOG_TOKEN_POS(tokens, pos_var) const auto pos_var = tokens.current_pos(); #else // ARM_COMPUTE_LOGGING_ENABLED @@ -73,19 +74,12 @@ namespace { void ltrim(std::string &str) { - str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char ch) - { - return !std::isspace(ch); - })); + str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char ch) { return !std::isspace(ch); })); } void rtrim(std::string &str) { - str.erase(std::find_if(str.rbegin(), str.rend(), [](char ch) - { - return !std::isspace(ch); - }).base(), - str.end()); + str.erase(std::find_if(str.rbegin(), str.rend(), [](char ch) { return !std::isspace(ch); }).base(), str.end()); } void trim(std::string &str) @@ -109,7 +103,7 @@ enum class ComparatorType }; TokenStream::TokenStream(std::istream &s, const std::string &delims) - : _delims{ delims }, _istream{ s }, _tokens{}, _lookahead_pos{} + : _delims{delims}, _istream{s}, _tokens{}, _lookahead_pos{} { read(); } @@ -125,7 +119,7 @@ Token TokenStream::take() ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty"); Token t = _tokens.front(); _tokens.pop_front(); - if(_tokens.empty()) + if (_tokens.empty()) { read(); } @@ -136,7 +130,7 @@ Token TokenStream::peek(size_t i) ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty"); ARM_COMPUTE_ERROR_ON_MSG(i >= max_look_ahead, "TokenStream: Exceeding max look ahead"); // NOTE: If i exceeds the stream (_istream.eof()), read() automatically appends a End token at the end - while(_istream && _tokens.size() <= i) + while (_istream && _tokens.size() <= i) { read(); } @@ -146,7 +140,7 @@ Token TokenStream::peek(size_t i) void advance(CharPosition &pos, char ch) { - if(ch == '\n') + if (ch == '\n') { pos.ln += 1; pos.col = 0; @@ -167,17 +161,16 @@ void TokenStream::read() do { // Reached eof - if(!_istream.get(ch)) + if (!_istream.get(ch)) { - if(!reached_end()) + if (!reached_end()) { _tokens.emplace_back(TokenType::End, "", _lookahead_pos); } return; } advance(_lookahead_pos, ch); - } - while(std::isspace(ch) || is_delim(ch)); + } while (std::isspace(ch) || is_delim(ch)); // Read chars until we hit a delim or eof auto orig_pos = _lookahead_pos; auto tok = recognize_tok(ch); @@ -190,41 +183,41 @@ void TokenStream::read() Token TokenStream::recognize_tok(char ch) { - if(ch == '[') + if (ch == '[') { - return Token{ TokenType::L_List, "", _lookahead_pos }; + return Token{TokenType::L_List, "", _lookahead_pos}; } - else if(ch == ']') + else if (ch == ']') { - return Token{ TokenType::R_List, "", _lookahead_pos }; + return Token{TokenType::R_List, "", _lookahead_pos}; } - else if(ch == '.') + else if (ch == '.') { - return float_after_dp_st(std::string{ ch }); + return float_after_dp_st(std::string{ch}); } - else if(std::isdigit(ch)) + else if (std::isdigit(ch)) { - return num_st(std::string{ ch }); + return num_st(std::string{ch}); } else { - return text_st(std::string{ ch }); + return text_st(std::string{ch}); } } Token TokenStream::num_st(std::string value) { char ch{}; - while(_istream.get(ch)) + while (_istream.get(ch)) { advance(_lookahead_pos, ch); - if(ch == '.') + if (ch == '.') { return float_after_dp_st(value + ch); } - else if(!std::isdigit(ch)) + else if (!std::isdigit(ch)) { - if(!is_delim(ch) && !std::isspace(ch)) + if (!is_delim(ch) && !std::isspace(ch)) { rewind(_lookahead_pos); _istream.unget(); @@ -233,18 +226,18 @@ Token TokenStream::num_st(std::string value) } value += ch; } - return Token{ TokenType::Int, value, _lookahead_pos }; + return Token{TokenType::Int, value, _lookahead_pos}; } Token TokenStream::float_after_dp_st(std::string value) { char ch{}; - while(_istream.get(ch)) + while (_istream.get(ch)) { advance(_lookahead_pos, ch); - if(!std::isdigit(ch)) + if (!std::isdigit(ch)) { - if(!is_delim(ch) && !std::isspace(ch)) + if (!is_delim(ch) && !std::isspace(ch)) { rewind(_lookahead_pos); _istream.unget(); @@ -253,20 +246,20 @@ Token TokenStream::float_after_dp_st(std::string value) } value += ch; } - return Token{ TokenType::Float, value, _lookahead_pos }; + return Token{TokenType::Float, value, _lookahead_pos}; } Token TokenStream::text_st(std::string value) { char ch{}; - while(_istream.get(ch)) + while (_istream.get(ch)) { advance(_lookahead_pos, ch); - if(is_delim(ch)) + if (is_delim(ch)) { break; } - if(ch == '[' || ch == ']') + if (ch == '[' || ch == ']') { rewind(_lookahead_pos); _istream.unget(); @@ -274,7 +267,7 @@ Token TokenStream::text_st(std::string value) } value += ch; } - return Token{ TokenType::Text, value, _lookahead_pos }; + return Token{TokenType::Text, value, _lookahead_pos}; } bool TokenStream::reached_end() const @@ -291,7 +284,7 @@ void end(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); auto tok = in.take(); - if(tok.type != TokenType::End) + if (tok.type != TokenType::End) { FAIL_WITH_MSG(valid, pos, "Unexpected token at the end of stream"); } @@ -301,7 +294,7 @@ bool bool_val(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); auto tok = in.take(); - if(tok.type != TokenType::Int) + if (tok.type != TokenType::Int) { FAIL_WITH_MSG_DEFAULT(valid, false, pos, "Expect bool or int token"); } @@ -314,7 +307,7 @@ int int_val(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); auto tok = in.take(); - if(tok.type != TokenType::Int) + if (tok.type != TokenType::Int) { FAIL_WITH_MSG_DEFAULT(valid, -1, pos, "Expect int token"); } @@ -327,7 +320,7 @@ unsigned int uint_val(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); int val = CHECK_DEFAULT(int_val(in, valid), valid, 0); - if(val < 0) + if (val < 0) { FAIL_WITH_MSG_DEFAULT(valid, 0, pos, "Expect unsigned int token"); } @@ -338,7 +331,7 @@ float float_val(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); auto tok = in.take(); - if(tok.type != TokenType::Float) + if (tok.type != TokenType::Float) { FAIL_WITH_MSG_DEFAULT(valid, 0.f, pos, "Expect float token"); } @@ -351,7 +344,7 @@ std::string text_val(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); auto tok = in.take(); - if(tok.type != TokenType::Text || tok.value.empty()) + if (tok.type != TokenType::Text || tok.value.empty()) { FAIL_WITH_MSG_DEFAULT(valid, "", pos, "Expect a non-empty text token"); } @@ -361,9 +354,9 @@ std::string text_val(TokenStream &in, bool &valid) bool accept_text(TokenStream &in, const std::string &c_str, bool take = true) { auto tok = in.peek(); - if(tok.type == TokenType::Text && tok.value == c_str) + if (tok.type == TokenType::Text && tok.value == c_str) { - if(take) + if (take) { in.take(); } @@ -375,7 +368,7 @@ bool accept_text(TokenStream &in, const std::string &c_str, bool take = true) void expect_text(TokenStream &in, const std::string &str, bool &valid) { LOG_TOKEN_POS(in, pos); - if(!accept_text(in, str)) + if (!accept_text(in, str)) { FAIL_WITH_MSG(valid, pos, std::string("Expect text token: ") + str); } @@ -384,7 +377,7 @@ void expect_text(TokenStream &in, const std::string &str, bool &valid) bool accept_l_list(TokenStream &in) { auto tok = in.peek(); - if(tok.type == TokenType::L_List) + if (tok.type == TokenType::L_List) { in.take(); return true; @@ -395,7 +388,7 @@ bool accept_l_list(TokenStream &in) void expect_l_list(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(!accept_l_list(in)) + if (!accept_l_list(in)) { FAIL_WITH_MSG(valid, pos, "Expect '['"); } @@ -404,7 +397,7 @@ void expect_l_list(TokenStream &in, bool &valid) bool accept_r_list(TokenStream &in) { auto tok = in.peek(); - if(tok.type == TokenType::R_List) + if (tok.type == TokenType::R_List) { in.take(); return true; @@ -415,7 +408,7 @@ bool accept_r_list(TokenStream &in) void expect_r_list(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(!accept_r_list(in)) + if (!accept_r_list(in)) { FAIL_WITH_MSG(valid, pos, "Expect ']'"); } @@ -424,23 +417,23 @@ void expect_r_list(TokenStream &in, bool &valid) ConditionalOp conditional_op(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "<=")) + if (accept_text(in, "<=")) { return ConditionalOp::LE; } - else if(accept_text(in, ">=")) + else if (accept_text(in, ">=")) { return ConditionalOp::GE; } - else if(accept_text(in, "==")) + else if (accept_text(in, "==")) { return ConditionalOp::EQ; } - else if(accept_text(in, "<")) + else if (accept_text(in, "<")) { return ConditionalOp::LT; } - else if(accept_text(in, ">")) + else if (accept_text(in, ">")) { return ConditionalOp::GT; } @@ -464,11 +457,11 @@ void ip_type(TokenStream &in, bool &valid) { CHECK(expect_text(in, "ip-type", valid), valid); LOG_TOKEN_POS(in, pos); - if(accept_text(in, "gpu")) + if (accept_text(in, "gpu")) { ; } - else if(accept_text(in, "cpu")) + else if (accept_text(in, "cpu")) { ; } @@ -489,15 +482,15 @@ void header(TokenStream &in, bool &valid) DataType data_type(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "f16")) + if (accept_text(in, "f16")) { return DataType::F16; } - else if(accept_text(in, "f32")) + else if (accept_text(in, "f32")) { return DataType::F32; } - else if(accept_text(in, "qasymm8")) + else if (accept_text(in, "qasymm8")) { return DataType::QASYMM8; } @@ -510,15 +503,15 @@ DataType data_type(TokenStream &in, bool &valid) ComparatorType comparator_type(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "var")) + if (accept_text(in, "var")) { return ComparatorType::Var; } - else if(accept_text(in, "num")) + else if (accept_text(in, "num")) { return ComparatorType::Num; } - else if(accept_text(in, "enum")) + else if (accept_text(in, "enum")) { return ComparatorType::Enum; } @@ -531,19 +524,19 @@ ComparatorType comparator_type(TokenStream &in, bool &valid) HeuristicType heuristic_type(TokenStream &in, bool &valid, bool take = true) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "gemm-type", take)) + if (accept_text(in, "gemm-type", take)) { return HeuristicType::GEMM_Type; } - else if(accept_text(in, "gemm-config-native", take)) + else if (accept_text(in, "gemm-config-native", take)) { return HeuristicType::GEMM_Config_Native; } - else if(accept_text(in, "gemm-config-reshaped-only-rhs", take)) + else if (accept_text(in, "gemm-config-reshaped-only-rhs", take)) { return HeuristicType::GEMM_Config_Reshaped_Only_RHS; } - else if(accept_text(in, "gemm-config-reshaped", take)) + else if (accept_text(in, "gemm-config-reshaped", take)) { return HeuristicType::GEMM_Config_Reshaped; } @@ -557,7 +550,7 @@ void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &val { LOG_TOKEN_POS(in, pos); auto ht = CHECK(heuristic_type(in, valid, false), valid); - if(ht != expected_ht) + if (ht != expected_ht) { FAIL_WITH_MSG(valid, pos, "Unexpected heuristic type"); } @@ -567,15 +560,15 @@ void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &val GEMMType gemm_type(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "native")) + if (accept_text(in, "native")) { return GEMMType::NATIVE; } - else if(accept_text(in, "reshaped-only-rhs")) + else if (accept_text(in, "reshaped-only-rhs")) { return GEMMType::RESHAPED_ONLY_RHS; } - else if(accept_text(in, "reshaped")) + else if (accept_text(in, "reshaped")) { return GEMMType::RESHAPED; } @@ -593,7 +586,7 @@ GEMMConfigNative gemm_config_native(TokenStream &in, bool &valid) const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val); - return GEMMConfigNative{ m0, n0, k0 }; + return GEMMConfigNative{m0, n0, k0}; } GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &valid) @@ -608,7 +601,7 @@ GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &v const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val); - return GEMMConfigReshapedOnlyRHS{ m0, n0, k0, h0, ir, tr, ex }; + return GEMMConfigReshapedOnlyRHS{m0, n0, k0, h0, ir, tr, ex}; } GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid) @@ -625,17 +618,17 @@ GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid) const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val); - return GEMMConfigReshaped{ m0, n0, k0, v0, h0, il, ir, tr, ex }; + return GEMMConfigReshaped{m0, n0, k0, v0, h0, il, ir, tr, ex}; } void gpu_priority(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "best-performance")) + if (accept_text(in, "best-performance")) { ; } - else if(accept_text(in, "best-memory-usage")) + else if (accept_text(in, "best-memory-usage")) { ; } @@ -648,11 +641,11 @@ void gpu_priority(TokenStream &in, bool &valid) void gpu_behavior(TokenStream &in, bool &valid) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "static")) + if (accept_text(in, "static")) { ; } - else if(accept_text(in, "dynamic")) + else if (accept_text(in, "dynamic")) { ; } @@ -665,7 +658,7 @@ void gpu_behavior(TokenStream &in, bool &valid) void free_vars(TokenStream &in, bool &valid) { CHECK(expect_l_list(in, valid), valid); - while(!accept_r_list(in)) + while (!accept_r_list(in)) { CHECK(text_val(in, valid), valid); } @@ -688,7 +681,7 @@ void heuristics_table_entry(TokenStream &in, MLGOHeuristics &h, bool &valid) void heuristics_table(TokenStream &in, MLGOHeuristics &h, bool &valid) { CHECK(expect_text(in, "", valid), valid); - while(!accept_text(in, "")) + while (!accept_text(in, "")) { CHECK(heuristics_table_entry(in, h, valid), valid); } @@ -705,11 +698,12 @@ Condition condition(TokenStream &in, bool &valid) const auto c_o = CHECK_DEFAULT(conditional_op(in, valid), valid, invalid_val); const auto r_t = CHECK_DEFAULT(comparator_type(in, valid), valid, invalid_val); const auto r_v = CHECK_DEFAULT(float_val(in, valid), valid, invalid_val); - if(l_t != ComparatorType::Var || r_t != ComparatorType::Num) + if (l_t != ComparatorType::Var || r_t != ComparatorType::Num) { - FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos, "Only accept LHS type to be Var (string) and RHS type to be Num (float)"); + FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos, + "Only accept LHS type to be Var (string) and RHS type to be Num (float)"); } - return Condition{ l_v, c_o, r_v }; + return Condition{l_v, c_o, r_v}; } void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid) @@ -717,13 +711,13 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid) CHECK(expect_text(in, "", valid), valid); - HeuristicTree *t = nullptr; - std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid); + HeuristicTree *t = nullptr; + std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid); const HeuristicType t_heuristic_type = std::get<0>(t->index()); - while(!accept_text(in, "")) + while (!accept_text(in, "")) { LOG_TOKEN_POS(in, pos); - if(accept_text(in, "b")) + if (accept_text(in, "b")) { // Branch node const auto id = CHECK(uint_val(in, valid), valid); @@ -732,7 +726,7 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid) const auto f_id = CHECK(uint_val(in, valid), valid); valid = CHECK(t->add_branch(id, cond, t_id, f_id), valid); } - else if(accept_text(in, "l")) + else if (accept_text(in, "l")) { // Leaf node const auto id = CHECK(uint_val(in, valid), valid); @@ -740,7 +734,7 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid) // heuristic table). For now it remains as a step for validation. LOG_TOKEN_POS(in, pos); CHECK(expect_heuristic_type(in, t_heuristic_type, valid), valid); - switch(t_heuristic_type) + switch (t_heuristic_type) { case HeuristicType::GEMM_Type: { @@ -786,7 +780,7 @@ MLGOHeuristics mlgo(TokenStream &in, bool &valid) MLGOHeuristics h; CHECK_DEFAULT(header(in, valid), valid, h); CHECK_DEFAULT(heuristics_table(in, h, valid), valid, h); - while(accept_text(in, " parse_mlgo(std::istream &in) #undef CHECK #undef CHECK_DEFAULT #undef FAIL_WITH_MSG -#undef FAIL_WITH_MSG_DEFAULT \ No newline at end of file +#undef FAIL_WITH_MSG_DEFAULT diff --git a/src/runtime/CL/mlgo/MLGOParser.h b/src/runtime/CL/mlgo/MLGOParser.h index 49d8b9c644..cffce8d6a1 100644 --- a/src/runtime/CL/mlgo/MLGOParser.h +++ b/src/runtime/CL/mlgo/MLGOParser.h @@ -98,15 +98,14 @@ struct CharPosition return ln == other.ln && col == other.col; } - size_t ln{ 0 }; - size_t col{ 0 }; + size_t ln{0}; + size_t col{0}; }; /** Token */ struct Token { - Token(TokenType t, std::string v, CharPosition pos) - : type{ t }, value{ v }, pos{ pos } + Token(TokenType t, std::string v, CharPosition pos) : type{t}, value{v}, pos{pos} { } @@ -196,4 +195,4 @@ std::pair parse_mlgo(std::istream &in); } // namespace parser } // namespace mlgo } // namespace arm_compute -#endif //SRC_RUNTIME_CL_MLGO_MLGO_PARSER_H \ No newline at end of file +#endif //SRC_RUNTIME_CL_MLGO_MLGO_PARSER_H diff --git a/src/runtime/CL/mlgo/Utils.cpp b/src/runtime/CL/mlgo/Utils.cpp index 81d418c28e..c7e0100b3c 100644 --- a/src/runtime/CL/mlgo/Utils.cpp +++ b/src/runtime/CL/mlgo/Utils.cpp @@ -43,40 +43,38 @@ inline std::string to_str(const T &val) std::ostream &operator<<(std::ostream &os, const GEMMConfigNative &config) { return os << "Native:{" - << "m0: " << config.m0 << ", " - << "n0: " << config.n0 << ", " - << "k0: " << config.k0 << ", " - << "}"; + << "m0: " << config.m0 << ", " + << "n0: " << config.n0 << ", " + << "k0: " << config.k0 << ", " + << "}"; } std::ostream &operator<<(std::ostream &os, const GEMMConfigReshapedOnlyRHS &config) { return os << "ReshapedOnlyRHS:{" - << "m0: " << config.m0 << ", " - << "n0: " << config.n0 << ", " - << "k0: " << config.k0 << ", " - << "h0: " << config.h0 << ", " - << "interleave_rhs: " << config.interleave_rhs << ", " - << "transpose_rhs: " << config.transpose_rhs << ", " - << "export_cl_image: " << config.export_cl_image - << "}"; + << "m0: " << config.m0 << ", " + << "n0: " << config.n0 << ", " + << "k0: " << config.k0 << ", " + << "h0: " << config.h0 << ", " + << "interleave_rhs: " << config.interleave_rhs << ", " + << "transpose_rhs: " << config.transpose_rhs << ", " + << "export_cl_image: " << config.export_cl_image << "}"; } std::ostream &operator<<(std::ostream &os, const GEMMConfigReshaped &config) { return os << "Reshaped:{" - << "m0: " << config.m0 << ", " - << "n0: " << config.n0 << ", " - << "k0: " << config.k0 << ", " - << "v0: " << config.v0 << ", " - << "h0: " << config.h0 << ", " - << "interleave_lhs: " << config.interleave_lhs << ", " - << "interleave_rhs: " << config.interleave_rhs << ", " - << "transpose_rhs: " << config.transpose_rhs << ", " - << "export_cl_image: " << config.export_cl_image - << "}"; + << "m0: " << config.m0 << ", " + << "n0: " << config.n0 << ", " + << "k0: " << config.k0 << ", " + << "v0: " << config.v0 << ", " + << "h0: " << config.h0 << ", " + << "interleave_lhs: " << config.interleave_lhs << ", " + << "interleave_rhs: " << config.interleave_rhs << ", " + << "transpose_rhs: " << config.transpose_rhs << ", " + << "export_cl_image: " << config.export_cl_image << "}"; } std::ostream &operator<<(std::ostream &os, HeuristicType ht) { - switch(ht) + switch (ht) { case HeuristicType::GEMM_Type: { @@ -103,7 +101,7 @@ std::ostream &operator<<(std::ostream &os, HeuristicType ht) } std::ostream &operator<<(std::ostream &os, DataType dt) { - switch(dt) + switch (dt) { case DataType::F32: { @@ -184,4 +182,4 @@ std::ostream &operator<<(std::ostream &os, const CharPosition &pos) } // namespace mlgo -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/src/runtime/CL/mlgo/Utils.h b/src/runtime/CL/mlgo/Utils.h index c634a887e9..73b537f476 100644 --- a/src/runtime/CL/mlgo/Utils.h +++ b/src/runtime/CL/mlgo/Utils.h @@ -43,10 +43,10 @@ std::ostream &operator<<(std::ostream &os, HeuristicType ht); std::ostream &operator<<(std::ostream &os, DataType dt); std::ostream &operator<<(std::ostream &os, const HeuristicTree::Index &index); std::ostream &operator<<(std::ostream &os, const Query &query); -std::string to_string(const GEMMConfigNative &config); -std::string to_string(const GEMMConfigReshapedOnlyRHS &config); -std::string to_string(const GEMMConfigReshaped &config); -std::string to_string(const Query &query); +std::string to_string(const GEMMConfigNative &config); +std::string to_string(const GEMMConfigReshapedOnlyRHS &config); +std::string to_string(const GEMMConfigReshaped &config); +std::string to_string(const Query &query); namespace parser { std::ostream &operator<<(std::ostream &os, const CharPosition &pos); @@ -54,4 +54,4 @@ std::ostream &operator<<(std::ostream &os, const CharPosition &pos); } // namespace mlgo } // namespace arm_compute -#endif //SRC_RUNTIME_CL_MLGO_UTILS_H \ No newline at end of file +#endif //SRC_RUNTIME_CL_MLGO_UTILS_H -- cgit v1.2.1