From 7061eb283969f9a020c08349454447564e4dd5b3 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Fri, 8 Jan 2021 15:16:02 +0000 Subject: Implement MLGO module * Implement MLGOHeuristics which provides a query and a loading interface * Implement a top-down parser MLGOParser for parsing dotmlgo * Add validation tests for MLGOHeuristics Resolves COMPMID-3840, COMPMID-3841 Signed-off-by: SiCong Li Change-Id: Iae96d2779524b2dd83623d1a3a30ef57823ae084 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4941 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- Android.bp | 4 + SConscript | 1 + src/runtime/CL/mlgo/Common.h | 81 +++ src/runtime/CL/mlgo/HeuristicTree.cpp | 253 +++++++++ src/runtime/CL/mlgo/HeuristicTree.h | 198 +++++++ src/runtime/CL/mlgo/MLGOHeuristics.cpp | 242 +++++++++ src/runtime/CL/mlgo/MLGOHeuristics.h | 139 +++++ src/runtime/CL/mlgo/MLGOParser.cpp | 812 ++++++++++++++++++++++++++++ src/runtime/CL/mlgo/MLGOParser.h | 199 +++++++ src/runtime/CL/mlgo/Utils.cpp | 143 +++++ src/runtime/CL/mlgo/Utils.h | 50 ++ tests/SConscript | 2 +- tests/validation/CL/UNIT/MLGOHeuristics.cpp | 461 ++++++++++++++++ 13 files changed, 2584 insertions(+), 1 deletion(-) create mode 100644 src/runtime/CL/mlgo/Common.h create mode 100644 src/runtime/CL/mlgo/HeuristicTree.cpp create mode 100644 src/runtime/CL/mlgo/HeuristicTree.h create mode 100644 src/runtime/CL/mlgo/MLGOHeuristics.cpp create mode 100644 src/runtime/CL/mlgo/MLGOHeuristics.h create mode 100644 src/runtime/CL/mlgo/MLGOParser.cpp create mode 100644 src/runtime/CL/mlgo/MLGOParser.h create mode 100644 src/runtime/CL/mlgo/Utils.cpp create mode 100644 src/runtime/CL/mlgo/Utils.h create mode 100644 tests/validation/CL/UNIT/MLGOHeuristics.cpp diff --git a/Android.bp b/Android.bp index 663adaedfc..f9761e2352 100644 --- a/Android.bp +++ b/Android.bp @@ -611,6 +611,10 @@ cc_library_static { "src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp", 
"src/runtime/CL/gemm/CLGEMMDefaultTypeMidgard.cpp", "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp", + "src/runtime/CL/mlgo/HeuristicTree.cpp", + "src/runtime/CL/mlgo/MLGOHeuristics.cpp", + "src/runtime/CL/mlgo/MLGOParser.cpp", + "src/runtime/CL/mlgo/Utils.cpp", "src/runtime/CL/tuners/BifrostTuner.cpp", "src/runtime/CL/tuners/CLTuningParametersList.cpp", "src/runtime/CL/tuners/MidgardTuner.cpp", diff --git a/SConscript b/SConscript index 7e09240768..307a0983d8 100644 --- a/SConscript +++ b/SConscript @@ -226,6 +226,7 @@ if env['opencl']: runtime_files += Glob('src/runtime/CL/tuners/*.cpp') runtime_files += Glob('src/runtime/gpu/cl/*.cpp') runtime_files += Glob('src/runtime/gpu/cl/operators/*.cpp') + runtime_files += Glob('src/runtime/CL/mlgo/*.cpp') graph_files += Glob('src/graph/backends/CL/*.cpp') diff --git a/src/runtime/CL/mlgo/Common.h b/src/runtime/CL/mlgo/Common.h new file mode 100644 index 0000000000..a2d3ec8241 --- /dev/null +++ b/src/runtime/CL/mlgo/Common.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_MLGO_COMMON_H +#define SRC_MLGO_COMMON_H + +#include "arm_compute/core/Types.h" +#include "arm_compute/runtime/CL/CLTypes.h" + +namespace arm_compute +{ +namespace mlgo +{ +/** Types of Heuristic (tree) */ +enum class HeuristicType +{ + GEMM_Type, /**< About the type of gemm */ + GEMM_Config_Native, /**< About the gemm config for native kernel */ + GEMM_Config_Reshaped_Only_RHS, /**< About the gemm config for reshaped only rhs kernel */ + GEMM_Config_Reshaped /**< About the gemm config for reshaped kernel */ +}; + +using GEMMType = CLGEMMKernelType; + +/** GEMM Configuration for Native kernel */ +struct GEMMConfigNative +{ + unsigned int m0; /**< Number of rows processed by the matrix multiplication */ + unsigned int n0; /**< Number of columns processed by the matrix multiplication */ + unsigned int k0; /**< Number of partial accumulations performed by the matrix multiplication */ +}; + +/** GEMM Configuration for Reshaped Only RHS kernel */ +struct GEMMConfigReshapedOnlyRHS +{ + unsigned int m0; /**< Number of rows processed by the matrix multiplication */ + unsigned int n0; /**< Number of columns processed by the matrix multiplication */ + unsigned int k0; /**< Number of partial accumulations performed by the matrix multiplication */ + unsigned int h0; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */ + bool interleave_rhs; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */ + bool transpose_rhs; /**< True if the (k0xn0) block has to be transposed before been stored */ + bool export_cl_image; /**< True if the reshaped rhs has to be exported to cl_image. 
n0 must be equal to 4 */ +}; + +/** GEMM Configuration for Reshaped kernel */ +struct GEMMConfigReshaped +{ + unsigned int m0; /**< Number of rows processed by the matrix multiplication */ + unsigned int n0; /**< Number of columns processed by the matrix multiplication */ + unsigned int k0; /**< Number of partial accumulations performed by the matrix multiplication */ + unsigned int v0; /**< Number of vertical blocks of size (m0xk0) stored on the same output row */ + unsigned int h0; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */ + bool interleave_lhs; /**< True if the v0 (m0xk0) blocks have to be interleaved in the output row */ + bool interleave_rhs; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */ + bool transpose_rhs; /**< True if the (k0xn0) block has to be transposed before been stored */ + bool export_cl_image; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */ +}; + +} // namespace mlgo +} // namespace arm_compute +#endif // SRC_MLGO_COMMON_H \ No newline at end of file diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp new file mode 100644 index 0000000000..65219998cb --- /dev/null +++ b/src/runtime/CL/mlgo/HeuristicTree.cpp @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/CL/mlgo/HeuristicTree.h" +#include "arm_compute/core/Log.h" + +#include +#include +#include +namespace arm_compute +{ +namespace mlgo +{ +namespace +{ +bool evaluate(GEMMShape shape, Condition cond) +{ + // PRE: all features and ConditionalOps are valid + constexpr float eps = 0.0001f; + // Calculate all secondary features + std::vector> cond_values + { + { "m", static_cast(shape.m) }, + { "n", static_cast(shape.n) }, + { "k", static_cast(shape.k) }, + { "b", static_cast(shape.b) }, + { "r_mn", static_cast(shape.m) / shape.n }, + { "r_mk", static_cast(shape.m) / shape.k }, + { "r_nk", static_cast(shape.n) / shape.k }, + { "r_mnk", static_cast(shape.m) / (static_cast(shape.n) / shape.k) }, + { "workload", (static_cast(shape.m) * shape.n * shape.b) / 20.0 } + }; + auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(), + [&cond](decltype(*cond_values.begin()) it) + { + return it.first == cond.feature; + }); + + ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end()); + const float cond_value = cond_value_pair_it->second; + switch(cond.op) + { + case ConditionalOp::LT: + { + return cond_value < cond.threshold; + } + case ConditionalOp::LE: + { + return cond_value <= cond.threshold; + } + case ConditionalOp::GT: + { + return cond_value > cond.threshold; + } + case ConditionalOp::GE: + { + return cond_value >= cond.threshold; + } + case ConditionalOp::EQ: + default: + { + return std::abs(cond_value - cond.threshold) < eps; + } + } 
+} + +} // namespace + +constexpr size_t HeuristicTree::_max_num_nodes; +constexpr size_t HeuristicTree::_max_query_depth; +constexpr HeuristicTree::NodeID HeuristicTree::_root; + +HeuristicTree::HeuristicTree() + : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32) +{ +} + +HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type) + : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{} +{ +} + +template +std::pair HeuristicTree::query(GEMMShape shape) const +{ + // Root ID = 0; + auto cur_node = _tree.at(_root).get(); + size_t depth = 0; + while(cur_node->type() != NodeType::Leaf) + { + if(depth > _max_query_depth) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth); + return std::make_pair(false, T{}); + } + ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType"); + auto br_node = dynamic_cast(cur_node); + if(evaluate(shape, br_node->condition)) + { + cur_node = _tree.at(br_node->true_node).get(); + } + else + { + cur_node = _tree.at(br_node->false_node).get(); + } + ++depth; + } + ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType"); + auto l_node = dynamic_cast *>(cur_node); + return std::make_pair(true, l_node->value); +} + +template +bool HeuristicTree::add_leaf(NodeID id, T val) +{ + if(_tree.size() >= _max_num_nodes) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); + return false; + } + if(_tree.find(id) != _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); + return false; + } + _tree[id] = std::make_unique>(id, val); + return true; +} + +bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node) +{ + if(_tree.size() >= _max_num_nodes) + { + 
ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes); + return false; + } + + const std::set supported_features = + { + "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload" + }; + const auto orig_feature = cond.feature; + std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c) // Cast via unsigned char: std::tolower has undefined behaviour for negative char values + { + return static_cast<char>(std::tolower(static_cast<unsigned char>(c))); + }); + if(supported_features.find(cond.feature) == supported_features.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str()); + return false; + } + + if(_tree.find(id) != _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id); + return false; + } + _tree[id] = std::make_unique(id, cond, t_node, f_node); + return true; +} + +bool HeuristicTree::check_if_structurally_correct() const +{ + std::set visited; + std::deque to_visit{ _root }; + + while(!to_visit.empty()) + { + auto id = to_visit.front(); + to_visit.pop_front(); + if(_tree.find(id) == _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id); + return false; + } + auto not_seen_before = visited.insert(id); + if(!not_seen_before.second) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops"); + return false; + } + auto cur_node = _tree.at(id).get(); + if(cur_node->type() == NodeType::Branch) + { + auto br_node = dynamic_cast(cur_node); + to_visit.push_back(br_node->true_node); + to_visit.push_back(br_node->false_node); + } + } + if(visited.size() != _tree.size()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes"); + return false; + } + return true; +} + +bool HeuristicTree::check() +{ + if(_tree.empty()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered"); + return false; + } + if(_tree.find(_root) == _tree.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. 
Root must have a Node ID of %zu", _root); + return false; + } + return check_if_structurally_correct(); +} + +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; +/** Explicit template instantiation @relates HeuristicTree */ +template std::pair HeuristicTree::query(GEMMShape shape) const; + +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMType val); +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val); +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val); +/** Explicit template instantiation @relates HeuristicTree */ +template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val); + +} // namespace mlgo + +} // namespace arm_compute diff --git a/src/runtime/CL/mlgo/HeuristicTree.h b/src/runtime/CL/mlgo/HeuristicTree.h new file mode 100644 index 0000000000..64d79ffaa1 --- /dev/null +++ b/src/runtime/CL/mlgo/HeuristicTree.h @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_MLGO_HEURISTICTREE_H +#define SRC_MLGO_HEURISTICTREE_H + +#include "arm_compute/core/Types.h" +#include "src/runtime/CL/mlgo/Common.h" + +#include +#include +#include +#include + +namespace arm_compute +{ +namespace mlgo +{ +/** Conditional ops */ +enum class ConditionalOp +{ + EQ, /**< Equal */ + LT, /**< Less than */ + LE, /**< Less than or equal to */ + GT, /**< Greater than */ + GE, /**< Greater than or equal to */ +}; + +/** A branch condition expression evaluating: feature op threshold */ +struct Condition +{ + std::string feature; /**< Feature name */ + ConditionalOp op; /**< Conditional op */ + float threshold; /**< Threshold value */ +}; + +/** GEMM Shape used for query */ +struct GEMMShape +{ + unsigned int m; /**< Number of rows for the lhs matrix. 
Lhs matrix NOT transposed */ + unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int b; /**< Batch size */ +}; + +/** A binary decision tree based heuristic */ +class HeuristicTree +{ +public: + using NodeID = size_t; + using TreeID = size_t; + using Index = std::tuple; + enum class NodeType + { + Branch, + Leaf + }; + struct Node + { + virtual NodeType type() const = 0; + virtual ~Node() = default; + }; + + struct BranchNode : public Node + { + BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node) + : id{ id }, condition{ cond }, true_node{ t_node }, false_node{ f_node } + { + } + NodeType type() const override + { + return NodeType::Branch; + } + NodeID id; + Condition condition; + NodeID true_node; + NodeID false_node; + }; + + template + struct LeafNode : public Node + { + LeafNode(NodeID id, T val) + : id{ id }, value{ val } + { + } + NodeType type() const override + { + return NodeType::Leaf; + } + NodeID id; + T value; + }; + +public: + /** Constructor */ + HeuristicTree(); + /** Constructor */ + HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type); + // Since the HeuristicTree is a handle that owns the nodes, it is move-only + /** Prevent copy construction */ + HeuristicTree(const HeuristicTree &) = delete; + /** Prevent copy assignment */ + HeuristicTree &operator=(const HeuristicTree &) = delete; + /** Move constructor */ + HeuristicTree(HeuristicTree &&other) noexcept = default; + /** Move assignment */ + HeuristicTree &operator=(HeuristicTree &&other) noexcept = default; + + /** Query a leaf value given a gemm shape + * + * @tparam T Leaf value type + * @param shape A @ref GEMMShape for the query + * @return std::pair Outcome contains bool, signalling if the query succeeded or not + */ + template + std::pair query(GEMMShape shape) const; + + /** Add a leaf 
node + * + * @tparam T Leaf value type + * @param id Leaf node ID + * @param leaf_value Leaf node value + * @return bool If the addition succeeded or not + */ + template + bool add_leaf(NodeID id, T leaf_value); + /** Add a branch node + * + * @param id Branch node ID + * @param cond Branch node @ref Condition + * @param true_node True node's ID + * @param false_node False node's ID + * @return bool If the addition succeeded or not + */ + bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node); + + /** Get tree ID + * @return TreeID + */ + TreeID id() const + { + return _id; + } + + /** Get tree index + * @return Index + */ + Index index() const + { + return std::make_tuple(_heuristic_type, _ip_target, _data_type); + } + + /** Check if tree is valid + * @return bool + */ + bool check(); + +private: + static constexpr size_t _max_query_depth{ 1000 }; // Maximum depth of query + static constexpr size_t _max_num_nodes{ 100000 }; // Maximum number of nodes contained by the tree + static constexpr NodeID _root{ 0 }; // Root tree ID + +private: + bool check_if_structurally_correct() const; + +private: + TreeID _id; /**< Heuristic tree ID */ + HeuristicType _heuristic_type; /**< Heuristic type */ + std::string _ip_target; /**< IP target associated with the tree */ + DataType _data_type; /**< Data type associated with the tree */ + std::map> _tree; /**< Tree representation */ +}; +} // namespace mlgo + +} // namespace arm_compute + +#endif //SRC_MLGO_HEURISTICTREE_H \ No newline at end of file diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.cpp b/src/runtime/CL/mlgo/MLGOHeuristics.cpp new file mode 100644 index 0000000000..ec99b50488 --- /dev/null +++ b/src/runtime/CL/mlgo/MLGOHeuristics.cpp @@ -0,0 +1,242 @@ +/* + * Copyright (c) 2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/runtime/CL/mlgo/MLGOHeuristics.h" +#include "arm_compute/core/Log.h" +#include "src/runtime/CL/mlgo/MLGOParser.h" + +#include + +namespace arm_compute +{ +namespace mlgo +{ +bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs) +{ + return std::tie(lhs.m0, lhs.n0, lhs.k0) == std::tie(rhs.m0, rhs.n0, rhs.k0); +} +bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs) +{ + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.h0, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.h0, rhs.interleave_rhs, rhs.transpose_rhs, + rhs.export_cl_image); +} +bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs) +{ + return std::tie(lhs.m0, lhs.n0, lhs.k0, lhs.v0, lhs.h0, lhs.interleave_lhs, lhs.interleave_rhs, lhs.transpose_rhs, lhs.export_cl_image) == std::tie(rhs.m0, rhs.n0, rhs.k0, rhs.v0, rhs.h0, + rhs.interleave_lhs, rhs.interleave_rhs, rhs.transpose_rhs, rhs.export_cl_image); +} + +constexpr size_t MLGOHeuristics::_max_num_trees; + +MLGOHeuristics::MLGOHeuristics() + : _indices{}, _trees{}, _tree_valid{}, _valid{ false } +{ +} + +std::pair MLGOHeuristics::query_gemm_type(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm type"); + const auto invalid = GEMMType::RESHAPED; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. 
Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query(shape_query); +} +std::pair MLGOHeuristics::query_gemm_config_native(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config native"); + const auto invalid = GEMMConfigNative{}; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query(shape_query); +} +std::pair MLGOHeuristics::query_gemm_config_reshaped_only_rhs(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config reshaped only rhs"); + const auto invalid = GEMMConfigReshapedOnlyRHS{}; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. 
Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query(shape_query); +} +std::pair MLGOHeuristics::query_gemm_config_reshaped(Query query) const +{ + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config reshaped"); + const auto invalid = GEMMConfigReshaped{}; + if(!_valid) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead"); + return { false, invalid }; + } + auto index = std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type); + GEMMShape shape_query{ query.m, query.n, query.k, query.b }; + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return { false, invalid }; + } + return _trees.at(index).query(shape_query); +} + +bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id) +{ + bool status; + HeuristicTree *tree{ nullptr }; + std::tie(status, tree) = get_heuristic_tree(id); + if(!status) + { + return status; + } + status = tree->check(); + if(!status) + { + return status; + } + _tree_valid[id] = true; + return true; +} + +bool MLGOHeuristics::check_all() const +{ + // Tree validities are already checked and cached. + bool all_trees_are_checked = std::find_if(_tree_valid.begin(), _tree_valid.end(), [](auto v) + { + return !v.second; + }) + == _tree_valid.end(); + if(!all_trees_are_checked) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo"); + return false; + } + + // Other top level checks... 
+ + return true; +} + +std::pair MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id) +{ + if(_indices.find(id) == _indices.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id); + return std::make_pair(false, nullptr); + } + const auto index = _indices[id]; + + if(_trees.find(index) == _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index"); + return std::make_pair(false, nullptr); + } + auto &t = _trees[index]; + + return std::make_pair(true, &t); +} + +bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t) +{ + if(_indices.size() >= _max_num_trees) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees); + return false; + } + // PRE: correctness of t is guaranteed by the tree construction process + // Ensure unique id + const auto id = t.id(); + if(_indices.find(id) != _indices.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", id); + return false; + } + + // Ensure unique index + const auto index = t.index(); + if(_trees.find(index) != _trees.end()) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists"); + return false; + } + + _indices[id] = index; + _trees[index] = std::move(t); + _tree_valid[id] = false; + return true; +} + +bool MLGOHeuristics::reload_from_file(const std::string &filename) +{ + std::ifstream fs; + fs.exceptions(std::ifstream::badbit); + fs.open(filename, std::ios::in); + if(!fs.is_open()) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str()); + return _valid = false; + } + return reload_from_stream(fs); +} + +bool MLGOHeuristics::reload_from_stream(std::istream &in) +{ + auto parsed = parser::parse_mlgo(in); + if(!parsed.first) + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. 
Use default heuristics instead"); + return _valid = false; + } + *this = std::move(parsed.second); + ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully"); + return _valid = true; +} + +} // namespace mlgo +} // namespace arm_compute \ No newline at end of file diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.h b/src/runtime/CL/mlgo/MLGOHeuristics.h new file mode 100644 index 0000000000..02e8111b6e --- /dev/null +++ b/src/runtime/CL/mlgo/MLGOHeuristics.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef SRC_RUNTIME_MLGO_MLGOHEURISTICS_H +#define SRC_RUNTIME_MLGO_MLGOHEURISTICS_H + +#include "src/runtime/CL/mlgo/Common.h" +#include "src/runtime/CL/mlgo/HeuristicTree.h" + +#include +#include +#include +#include +namespace arm_compute +{ +namespace mlgo +{ +/** Query interface */ +struct Query +{ + std::string ip_target; /**< The name of the IP target */ + DataType data_type; /**< Data type */ + unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ + unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int b; /**< Batch size */ +}; + +bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs); +bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs); + +/** MLGOHeuristics for configuring GEMM kernels */ +class MLGOHeuristics +{ +public: + /** Constructor */ + MLGOHeuristics(); + /** Query the gemm type + * + * @param[in] query Query + * + * @return std::pair signals if the query succeeded or failed + */ + std::pair query_gemm_type(Query) const; + /** Query the gemm configuration for native kernel + * + * @param[in] query Query + * + * @return std::pair bool signals if the query succeeded or failed + */ + std::pair query_gemm_config_native(Query query) const; + /** Query the gemm configuration for reshaped only rhs kernel + * + * @param[in] query Query + * + * @return std::pair bool signals if the query succeeded or failed + */ + std::pair query_gemm_config_reshaped_only_rhs(Query) const; + /** Query the gemm configuration for reshaped kernel + * + * @param[in] query Query + * + * @return std::pair bool signals if the query succeeded or failed + */ + std::pair query_gemm_config_reshaped(Query) const; + /** (Re)Load the heuristics from reading a dotmlgo file + * + * @param[in] filename Path to the dotmlgo file + * + * @return bool Signals 
if the reload succeeded or failed + */ + bool reload_from_file(const std::string &filename); + /** (Re)Load the heuristics from reading an input stream + * + * @param[in] istream Istream containing mlgo heuristics + * + * @return bool Signals if the reload succeeded or failed + */ + bool reload_from_stream(std::istream &istream); + + /** Get the heuristic tree from tree id + * + * @param[in] id Tree id. + * + * @return HeuristicTree& + */ + std::pair get_heuristic_tree(HeuristicTree::TreeID id); + /** Add a heuristic tree + * @param t Heuristic tree to be added + */ + bool add_heuristic_tree(HeuristicTree &&t); + + /** Check the validity of the heuristic tree. + * + * @param id ID of the tree to be checked + * + * @return bool + */ + bool check_heuristic_tree(HeuristicTree::TreeID id); + + /** Check the overall validity of the heuristics. + * @return bool + */ + bool check_all() const; + +private: + static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/ + +private: + // There exists a one-to-one mappipng between TreeID and Index, either can be used to identify a @ref HeuristicTree + std::map _indices; /**< A mapping from TreeID to Index */ + std::map _trees; /**< A mapping from Index to HeuristicTree */ + std::map _tree_valid; /**< Result cache of the tree validity checks */ + bool _valid; /**< Overall validity */ +}; + +} // namespace mlgo +} // namespace arm_compute +#endif //SRC_MLGO_MLGOHEURISTICS_H \ No newline at end of file diff --git a/src/runtime/CL/mlgo/MLGOParser.cpp b/src/runtime/CL/mlgo/MLGOParser.cpp new file mode 100644 index 0000000000..625739e450 --- /dev/null +++ b/src/runtime/CL/mlgo/MLGOParser.cpp @@ -0,0 +1,812 @@ +/* + * Copyright (c) 2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/runtime/CL/mlgo/MLGOParser.h" +#include "arm_compute/core/Log.h" +#include "src/runtime/CL/mlgo/Utils.h" + +#include + +#define CHECK(parser_expr, valid_var) \ + (parser_expr); \ + if(!valid_var) \ + return; + +#define CHECK_DEFAULT(parser_expr, valid_var, default_val) \ + (parser_expr); \ + if(!valid_var) \ + return default_val; + +#ifdef ARM_COMPUTE_LOGGING_ENABLED + +#define FAIL_WITH_MSG(valid_var, pos, msg) \ + std::stringstream ss; \ + ss << "MLGOParser Error: " << pos << " " << msg; \ + ARM_COMPUTE_LOG_INFO_MSG_CORE(ss.str().c_str()); \ + valid_var = false; \ + return; + +#define FAIL_WITH_MSG_DEFAULT(valid_var, default_val, pos, msg) \ + std::stringstream ss; \ + ss << "MLGOParser Error: " << pos << " " << msg; \ + ARM_COMPUTE_LOG_INFO_MSG_CORE(ss.str().c_str()); \ + valid_var = false; \ + return default_val; + +#define LOG_TOKEN_POS(tokens, pos_var) \ + const auto pos_var = tokens.current_pos(); + +#else // ARM_COMPUTE_LOGGING_ENABLED + +#define FAIL_WITH_MSG(valid_var, pos, msg) \ + valid_var = false; \ + return; + +#define FAIL_WITH_MSG_DEFAULT(valid_var, default_val, pos, msg) \ + valid_var = false; \ + return default_val; + +#define LOG_TOKEN_POS(tokens, pos_var) + +#endif // ARM_COMPUTE_LOGGING_ENABLED +namespace +{ +void ltrim(std::string &str) +{ + str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char ch) + { + return !std::isspace(ch); + })); +} + +void rtrim(std::string &str) +{ + str.erase(std::find_if(str.rbegin(), str.rend(), [](char ch) + { + return !std::isspace(ch); + }).base(), + str.end()); +} + +void trim(std::string &str) +{ + ltrim(str); + rtrim(str); +} +} // namespace + +namespace arm_compute +{ +namespace mlgo +{ +namespace parser +{ +enum class ComparatorType +{ + Enum, + Num, + Var +}; + +TokenStream::TokenStream(std::istream &s, const std::string &delims) + : _delims{ delims }, _istream{ s }, _tokens{}, _lookahead_pos{} +{ + read(); +} + +TokenStream::operator bool() const +{ + 
ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty"); + return !reached_end(); +} + +Token TokenStream::take() +{ + ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty"); + Token t = _tokens.front(); + _tokens.pop_front(); + if(_tokens.empty()) + { + read(); + } + return t; +} +Token TokenStream::peek(size_t i) +{ + ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty"); + ARM_COMPUTE_ERROR_ON_MSG(i >= max_look_ahead, "TokenStream: Exceeding max look ahead"); + // NOTE: If i exceeds the stream (_istream.eof()), read() automatically appends a End token at the end + while(_istream && _tokens.size() <= i) + { + read(); + } + size_t ind = std::min(i, _tokens.size() - 1); + return _tokens[ind]; +} + +void advance(CharPosition &pos, char ch) +{ + if(ch == '\n') + { + pos.ln += 1; + pos.col = 0; + } + else + { + pos.col += 1; + } +} +void rewind(CharPosition &pos) +{ + pos.col -= 1; +} +void TokenStream::read() +{ + char ch; + // Skip any leading space and delim characters + do + { + // Reached eof + if(!_istream.get(ch)) + { + if(!reached_end()) + { + _tokens.emplace_back(TokenType::End, "", _lookahead_pos); + } + return; + } + advance(_lookahead_pos, ch); + } + while(std::isspace(ch) || is_delim(ch)); + // Read chars until we hit a delim or eof + auto orig_pos = _lookahead_pos; + auto tok = recognize_tok(ch); + rewind(orig_pos); + tok.pos = orig_pos; + // Trim leading and trailing white spaces + trim(tok.value); + _tokens.push_back(tok); +} + +Token TokenStream::recognize_tok(char ch) +{ + if(ch == '[') + { + return Token{ TokenType::L_List, "", _lookahead_pos }; + } + else if(ch == ']') + { + return Token{ TokenType::R_List, "", _lookahead_pos }; + } + else if(ch == '.') + { + return float_after_dp_st(std::string{ ch }); + } + else if(std::isdigit(ch)) + { + return num_st(std::string{ ch }); + } + else + { + return text_st(std::string{ ch }); + } +} + +Token TokenStream::num_st(std::string value) +{ + 
char ch{}; + while(_istream.get(ch)) + { + advance(_lookahead_pos, ch); + if(ch == '.') + { + return float_after_dp_st(value + ch); + } + else if(!std::isdigit(ch)) + { + if(!is_delim(ch) && !std::isspace(ch)) + { + rewind(_lookahead_pos); + _istream.unget(); + } + break; + } + value += ch; + } + return Token{ TokenType::Int, value, _lookahead_pos }; +} + +Token TokenStream::float_after_dp_st(std::string value) +{ + char ch{}; + while(_istream.get(ch)) + { + advance(_lookahead_pos, ch); + if(!std::isdigit(ch)) + { + if(!is_delim(ch) && !std::isspace(ch)) + { + rewind(_lookahead_pos); + _istream.unget(); + } + break; + } + value += ch; + } + return Token{ TokenType::Float, value, _lookahead_pos }; +} + +Token TokenStream::text_st(std::string value) +{ + char ch{}; + while(_istream.get(ch)) + { + advance(_lookahead_pos, ch); + if(is_delim(ch)) + { + break; + } + if(ch == '[' || ch == ']') + { + rewind(_lookahead_pos); + _istream.unget(); + break; + } + value += ch; + } + return Token{ TokenType::Text, value, _lookahead_pos }; +} + +bool TokenStream::reached_end() const +{ + return _tokens.size() == 1 && _tokens.front().type == TokenType::End; +} + +bool TokenStream::is_delim(char ch) const +{ + return _delims.find(ch) != std::string::npos; +} + +void end(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + auto tok = in.take(); + if(tok.type != TokenType::End) + { + FAIL_WITH_MSG(valid, pos, "Unexpected token at the end of stream"); + } +} + +bool bool_val(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + auto tok = in.take(); + if(tok.type != TokenType::Int) + { + FAIL_WITH_MSG_DEFAULT(valid, false, pos, "Expect bool or int token"); + } + bool val{}; + std::stringstream(tok.value) >> val; + return val; +} + +int int_val(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + auto tok = in.take(); + if(tok.type != TokenType::Int) + { + FAIL_WITH_MSG_DEFAULT(valid, -1, pos, "Expect int token"); + } + int val{}; + std::stringstream(tok.value) >> 
val; + return val; +} + +unsigned int uint_val(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + int val = CHECK_DEFAULT(int_val(in, valid), valid, 0); + if(val < 0) + { + FAIL_WITH_MSG_DEFAULT(valid, 0, pos, "Expect unsigned int token"); + } + return static_cast(val); +} + +float float_val(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + auto tok = in.take(); + if(tok.type != TokenType::Float) + { + FAIL_WITH_MSG_DEFAULT(valid, 0.f, pos, "Expect float token"); + } + float val{}; + std::stringstream(tok.value) >> val; + return val; +} + +std::string text_val(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + auto tok = in.take(); + if(tok.type != TokenType::Text || tok.value.empty()) + { + FAIL_WITH_MSG_DEFAULT(valid, "", pos, "Expect a non-empty text token"); + } + return tok.value; +} + +bool accept_text(TokenStream &in, const std::string &c_str, bool take = true) +{ + auto tok = in.peek(); + if(tok.type == TokenType::Text && tok.value == c_str) + { + if(take) + { + in.take(); + } + return true; + } + return false; +} + +void expect_text(TokenStream &in, const std::string &str, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(!accept_text(in, str)) + { + FAIL_WITH_MSG(valid, pos, std::string("Expect text token: ") + str); + } +} + +bool accept_l_list(TokenStream &in) +{ + auto tok = in.peek(); + if(tok.type == TokenType::L_List) + { + in.take(); + return true; + } + return false; +} + +void expect_l_list(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(!accept_l_list(in)) + { + FAIL_WITH_MSG(valid, pos, "Expect '['"); + } +} + +bool accept_r_list(TokenStream &in) +{ + auto tok = in.peek(); + if(tok.type == TokenType::R_List) + { + in.take(); + return true; + } + return false; +} + +void expect_r_list(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(!accept_r_list(in)) + { + FAIL_WITH_MSG(valid, pos, "Expect ']'"); + } +} + +ConditionalOp conditional_op(TokenStream &in, bool &valid) +{ + 
LOG_TOKEN_POS(in, pos); + if(accept_text(in, "<=")) + { + return ConditionalOp::LE; + } + else if(accept_text(in, ">=")) + { + return ConditionalOp::GE; + } + else if(accept_text(in, "==")) + { + return ConditionalOp::EQ; + } + else if(accept_text(in, "<")) + { + return ConditionalOp::LT; + } + else if(accept_text(in, ">")) + { + return ConditionalOp::GT; + } + else + { + FAIL_WITH_MSG_DEFAULT(valid, ConditionalOp::EQ, pos, "Expect conditional op"); + } +} + +void gemm_version(TokenStream &in, bool &valid) +{ + CHECK(expect_text(in, "gemm-version", valid), valid); + CHECK(expect_l_list(in, valid), valid); + CHECK(uint_val(in, valid), valid); + CHECK(uint_val(in, valid), valid); + CHECK(uint_val(in, valid), valid); + CHECK(expect_r_list(in, valid), valid); +} + +void ip_type(TokenStream &in, bool &valid) +{ + CHECK(expect_text(in, "ip-type", valid), valid); + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "gpu")) + { + ; + } + else if(accept_text(in, "cpu")) + { + ; + } + else + { + FAIL_WITH_MSG(valid, pos, "Expect ip type"); + } +} + +void header(TokenStream &in, bool &valid) +{ + CHECK(expect_text(in, "
", valid), valid); + CHECK(gemm_version(in, valid), valid); + CHECK(ip_type(in, valid), valid); + CHECK(expect_text(in, "
", valid), valid); +} + +DataType data_type(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "f16")) + { + return DataType::F16; + } + else if(accept_text(in, "f32")) + { + return DataType::F32; + } + else if(accept_text(in, "qasymm8")) + { + return DataType::QASYMM8; + } + else + { + FAIL_WITH_MSG_DEFAULT(valid, DataType::QASYMM8, pos, "Expect data type"); + } +} + +ComparatorType comparator_type(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "var")) + { + return ComparatorType::Var; + } + else if(accept_text(in, "num")) + { + return ComparatorType::Num; + } + else if(accept_text(in, "enum")) + { + return ComparatorType::Enum; + } + else + { + FAIL_WITH_MSG_DEFAULT(valid, ComparatorType::Num, pos, "Expect comparator type"); + } +} + +HeuristicType heuristic_type(TokenStream &in, bool &valid, bool take = true) +{ + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "gemm-type", take)) + { + return HeuristicType::GEMM_Type; + } + else if(accept_text(in, "gemm-config-native", take)) + { + return HeuristicType::GEMM_Config_Native; + } + else if(accept_text(in, "gemm-config-reshaped-only-rhs", take)) + { + return HeuristicType::GEMM_Config_Reshaped_Only_RHS; + } + else if(accept_text(in, "gemm-config-reshaped", take)) + { + return HeuristicType::GEMM_Config_Reshaped; + } + else + { + FAIL_WITH_MSG_DEFAULT(valid, HeuristicType::GEMM_Config_Reshaped, pos, "Expect heuristic type"); + } +} + +void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + auto ht = CHECK(heuristic_type(in, valid, false), valid); + if(ht != expected_ht) + { + FAIL_WITH_MSG(valid, pos, "Unexpected heuristic type"); + } + CHECK(heuristic_type(in, valid, true), valid); +} + +GEMMType gemm_type(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "native")) + { + return GEMMType::NATIVE; + } + else if(accept_text(in, "reshaped-only-rhs")) + { + return 
GEMMType::RESHAPED_ONLY_RHS; + } + else if(accept_text(in, "reshaped")) + { + return GEMMType::RESHAPED; + } + else + { + FAIL_WITH_MSG_DEFAULT(valid, GEMMType::RESHAPED_ONLY_RHS, pos, "Expect gemm type"); + } +} + +GEMMConfigNative gemm_config_native(TokenStream &in, bool &valid) +{ + const auto invalid_val = GEMMConfigNative{}; + CHECK_DEFAULT(expect_l_list(in, valid), valid, invalid_val); + const auto m0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val); + return GEMMConfigNative{ m0, n0, k0 }; +} + +GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &valid) +{ + const auto invalid_val = GEMMConfigReshapedOnlyRHS{}; + CHECK_DEFAULT(expect_l_list(in, valid), valid, invalid_val); + const auto m0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto h0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto ir = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); + const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); + const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); + CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val); + return GEMMConfigReshapedOnlyRHS{ m0, n0, k0, h0, ir, tr, ex }; +} + +GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid) +{ + const auto invalid_val = GEMMConfigReshaped{}; + CHECK_DEFAULT(expect_l_list(in, valid), valid, invalid_val); + const auto m0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + 
const auto v0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto h0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val); + const auto il = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); + const auto ir = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); + const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); + const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val); + CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val); + return GEMMConfigReshaped{ m0, n0, k0, v0, h0, il, ir, tr, ex }; +} + +void gpu_priority(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "best-performance")) + { + ; + } + else if(accept_text(in, "best-memory-usage")) + { + ; + } + else + { + FAIL_WITH_MSG(valid, pos, "Expect gpu priority"); + } +} + +void gpu_behavior(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "static")) + { + ; + } + else if(accept_text(in, "dynamic")) + { + ; + } + else + { + FAIL_WITH_MSG(valid, pos, "Expect ip type"); + } +} + +void free_vars(TokenStream &in, bool &valid) +{ + CHECK(expect_l_list(in, valid), valid); + while(!accept_r_list(in)) + { + CHECK(text_val(in, valid), valid); + } +} + +void heuristics_table_entry(TokenStream &in, MLGOHeuristics &h, bool &valid) +{ + const auto id = CHECK(uint_val(in, valid), valid); + const auto ip = CHECK(text_val(in, valid), valid); + CHECK(uint_val(in, valid), valid); // Num cores + const auto dt = CHECK(data_type(in, valid), valid); + CHECK(gpu_priority(in, valid), valid); + CHECK(gpu_behavior(in, valid), valid); + const auto ht = CHECK(heuristic_type(in, valid), valid); + CHECK(free_vars(in, valid), valid); + HeuristicTree t(id, ht, ip, dt); + valid = CHECK(h.add_heuristic_tree(std::move(t)), valid); +} + +void heuristics_table(TokenStream &in, MLGOHeuristics &h, bool &valid) +{ + CHECK(expect_text(in, "", valid), valid); + while(!accept_text(in, "")) + { + 
CHECK(heuristics_table_entry(in, h, valid), valid); + } +} + +Condition condition(TokenStream &in, bool &valid) +{ + LOG_TOKEN_POS(in, pos); + // NOTE: Only simplified Conditions are accepted, which means the lhs comparator type is fixed to Var and that of + // the rhs is fixed to Num (float) + const auto invalid_val = Condition{}; + const auto l_t = CHECK_DEFAULT(comparator_type(in, valid), valid, invalid_val); + const auto l_v = CHECK_DEFAULT(text_val(in, valid), valid, invalid_val); + const auto c_o = CHECK_DEFAULT(conditional_op(in, valid), valid, invalid_val); + const auto r_t = CHECK_DEFAULT(comparator_type(in, valid), valid, invalid_val); + const auto r_v = CHECK_DEFAULT(float_val(in, valid), valid, invalid_val); + if(l_t != ComparatorType::Var || r_t != ComparatorType::Num) + { + FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos, "Only accept LHS type to be Var (string) and RHS type to be Num (float)"); + } + return Condition{ l_v, c_o, r_v }; +} + +void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid) +{ + CHECK(expect_text(in, "", valid), valid); + HeuristicTree *t = nullptr; + std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid); + const HeuristicType t_heuristic_type = std::get<0>(t->index()); + while(!accept_text(in, "")) + { + LOG_TOKEN_POS(in, pos); + if(accept_text(in, "b")) + { + // Branch node + const auto id = CHECK(uint_val(in, valid), valid); + const auto cond = CHECK(condition(in, valid), valid); + const auto t_id = CHECK(uint_val(in, valid), valid); + const auto f_id = CHECK(uint_val(in, valid), valid); + valid = CHECK(t->add_branch(id, cond, t_id, f_id), valid); + } + else if(accept_text(in, "l")) + { + // Leaf node + const auto id = CHECK(uint_val(in, valid), valid); + // NOTE: Heuristic type within each tree appears to be redundant (same information can be obtained from the + // heuristic table). For now it remains as a step for validation. 
+ LOG_TOKEN_POS(in, pos); + CHECK(expect_heuristic_type(in, t_heuristic_type, valid), valid); + switch(t_heuristic_type) + { + case HeuristicType::GEMM_Type: + { + const auto g_type = CHECK(gemm_type(in, valid), valid); + valid = CHECK(t->add_leaf(id, g_type), valid); + break; + } + case HeuristicType::GEMM_Config_Native: + { + const auto g_c = CHECK(gemm_config_native(in, valid), valid); + valid = CHECK(t->add_leaf(id, g_c), valid); + break; + } + case HeuristicType::GEMM_Config_Reshaped_Only_RHS: + { + const auto g_c = CHECK(gemm_config_reshaped_only_rhs(in, valid), valid); + valid = CHECK(t->add_leaf(id, g_c), valid); + break; + } + case HeuristicType::GEMM_Config_Reshaped: + { + const auto g_c = CHECK(gemm_config_reshaped(in, valid), valid); + valid = CHECK(t->add_leaf(id, g_c), valid); + break; + } + default: + { + FAIL_WITH_MSG(valid, pos, "Unexpected heuristic type"); + } + } + } + else + { + FAIL_WITH_MSG(valid, pos, "Expect tree node type"); + } + } + // Perform semantic checks in the middle of parsing so that it can fail fast should there be any invalidities + valid = CHECK(h.check_heuristic_tree(tree_id), valid); +} + +MLGOHeuristics mlgo(TokenStream &in, bool &valid) +{ + MLGOHeuristics h; + CHECK_DEFAULT(header(in, valid), valid, h); + CHECK_DEFAULT(heuristics_table(in, h, valid), valid, h); + while(accept_text(in, " parse_mlgo(std::istream &in) +{ + auto tokens = TokenStream(in); + bool valid = true; + auto h = mlgo(tokens, valid); + return std::make_pair(std::move(valid), std::move(h)); +} +} // namespace parser +} // namespace mlgo +} // namespace arm_compute + +#undef CHECK +#undef CHECK_DEFAULT +#undef FAIL_WITH_MSG +#undef FAIL_WITH_MSG_DEFAULT \ No newline at end of file diff --git a/src/runtime/CL/mlgo/MLGOParser.h b/src/runtime/CL/mlgo/MLGOParser.h new file mode 100644 index 0000000000..e4a31c1f55 --- /dev/null +++ b/src/runtime/CL/mlgo/MLGOParser.h @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2021 Arm Limited. 
 + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_MLGO_MLGOPARSER_H +#define SRC_MLGO_MLGOPARSER_H + +#include "src/runtime/CL/mlgo/MLGOHeuristics.h" + +#include +#include +#include +#include + +/** A DotMLGO file parser (LL(k) parser) + * + * The grammar of DotMLGO is defined as the following EBNF: + * + * delim = "," | "\n"; // Note that delimiters are omitted from the definition below + * + * mlgo = header, heuristics-table, {heuristic-tree}; + * + * header = "<header>", gemm-version, ip-type, "</header>"; + * gemm-version = "gemm-version", "[", int, int, int, "]"; + * ip-type = "ip-type", ("gpu" | "cpu"); + * + * heuristics-table = "<heuristics-table>", {heuristics-table-entry}, "</heuristics-table>"; + * heuristics-table-entry = entry-id, ip-name, num-cores, data-type, gpu-priority, gpu-behavior, heuristic-type, free-vars; + * entry-id = int; + * ip-name = char-sequence; + * num-cores = int; + * data-type = "f32" | "f16" | "qasymm8"; + * gpu-priority = "best-performance" | "best-memory-usage"; + * gpu-behavior = "static" | "dynamic"; + * heuristic-type = "gemm-type" | "gemm-config-native" | "gemm-config-reshaped-only-rhs" | "gemm-config-reshaped"; + * free-vars = "[", {char-sequence}, "]"; + * + * heuristic-tree = "<heuristic", entry-id, ">", {tree-node}, "</heuristic>"; + * tree-node = branch-node | leaf-node; + * branch-node = "b", entry-id, lhs-type, lhs-value, conditional-op, rhs-type, rhs-value, true-node, false-node; + * lhs-type = comparator-type; + * lhs-value = comparator-value; + * rhs-type = comparator-type; + * rhs-value = comparator-value; + * comparator-type = "var" | "num" | "enum"; + * comparator-value = char-sequence | float; + * conditional-op = "<" | "<=" | "==" | ">=" | ">"; + * true-node = entry-id; + * false-node = entry-id; + * leaf-node = "l", entry-id, heuristic-type, leaf-value; + * leaf-value = gemm-type | gemm-config-native | gemm-config-reshaped-only-rhs | gemm-config-reshaped; + * gemm-type = "native" | "reshaped-only-rhs" | "reshaped"; + * gemm-config-native = "[", int, int, int, "]"; + * gemm-config-reshaped-only-rhs = "[", int, int, int, int, bool, bool, bool, "]"; + * gemm-config-reshaped = "[", int, int, int, int, int, bool, bool, bool, bool, "]"; + */ + +namespace arm_compute +{ +namespace mlgo +{ +namespace parser +{ +/** Type of Token */ +enum class TokenType +{ + L_List = '[', /**< List open */ + R_List = ']', /**< List close */ + Int, /**< Integral */ + Float, /**< Floating */ + Text, /**< Text/String */ + End, /**< End of stream */ +}; + +struct CharPosition +{ + bool operator==(const
CharPosition &other) const + { + return ln == other.ln && col == other.col; + } + + size_t ln{ 0 }; + size_t col{ 0 }; +}; + +/** Token */ +struct Token +{ + Token(TokenType t, std::string v, CharPosition pos) + : type{ t }, value{ v }, pos{ pos } + { + } + + bool operator==(const Token &other) const + { + return type == other.type && value == other.value && pos == other.pos; + } + + TokenType type; /**< Token type */ + std::string value; /**< Token value */ + CharPosition pos; +}; + +/** A stream of token */ +class TokenStream +{ + // NOTE: _tokens is never empty. The end of token stream is signalled by the End Token +public: + static constexpr size_t max_look_ahead = 10; + +public: + /** Constructor + * + * @param[in] s Input stream + * @param[in] delims Delimiter characters packed in a string. Each char from the string can be used as a delim on its own + */ + TokenStream(std::istream &s, const std::string &delims = ",\n"); + + /** Check if there're more (non-End) Tokens + * @return true If there are more tokens + * @return false If reached end of stream (only End token) + */ + explicit operator bool() const; + + /** Get and pop off the current token + * + * @return Token + */ + Token take(); + + /** Peek the next ith token + * + * @param[in] i The next ith token. i < @ref max_look_ahead. 
+ * + * @return Token + */ + Token peek(size_t i = 0); + + /** Get the position of the current token + * + * @return CharPosition + */ + CharPosition current_pos() const + { + return _tokens.front().pos; + } + +private: + void read(); + + Token recognize_tok(char ch); + + Token num_st(std::string value = ""); + + Token float_after_dp_st(std::string value = ""); + + Token text_st(std::string value = ""); + + bool reached_end() const; + + bool is_delim(char ch) const; + + std::string _delims; + std::istream &_istream; + std::deque _tokens; + CharPosition _lookahead_pos; +}; + +/** Parse and construct a @ref MLGOHeuristics from input stream + * + * @param[in] in Input stream + * + * @return MLGOHeuristics + */ +std::pair parse_mlgo(std::istream &in); + +} // namespace parser +} // namespace mlgo +} // namespace arm_compute +#endif //SRC_MLGO_MLGOPARSER_H \ No newline at end of file diff --git a/src/runtime/CL/mlgo/Utils.cpp b/src/runtime/CL/mlgo/Utils.cpp new file mode 100644 index 0000000000..bd06bdf521 --- /dev/null +++ b/src/runtime/CL/mlgo/Utils.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/CL/mlgo/Utils.h" + +namespace arm_compute +{ +namespace mlgo +{ +std::ostream &operator<<(std::ostream &os, const GEMMConfigNative &config) +{ + return os << "Native:{" + << "m0: " << config.m0 << ", " + << "n0: " << config.n0 << ", " + << "k0: " << config.k0 + << "}"; +} +std::ostream &operator<<(std::ostream &os, const GEMMConfigReshapedOnlyRHS &config) +{ + return os << "ReshapedOnlyRHS:{" + << "m0: " << config.m0 << ", " + << "n0: " << config.n0 << ", " + << "k0: " << config.k0 << ", " + << "h0: " << config.h0 << ", " + << "interleave_rhs: " << config.interleave_rhs << ", " + << "transpose_rhs: " << config.transpose_rhs << ", " + << "export_cl_image: " << config.export_cl_image + << "}"; +} +std::ostream &operator<<(std::ostream &os, const GEMMConfigReshaped &config) +{ + return os << "Reshaped:{" + << "m0: " << config.m0 << ", " + << "n0: " << config.n0 << ", " + << "k0: " << config.k0 << ", " + << "v0: " << config.v0 << ", " + << "h0: " << config.h0 << ", " + << "interleave_lhs: " << config.interleave_lhs << ", " + << "interleave_rhs: " << config.interleave_rhs << ", " + << "transpose_rhs: " << config.transpose_rhs << ", " + << "export_cl_image: " << config.export_cl_image + << "}"; +} +std::ostream &operator<<(std::ostream &os, const HeuristicType &ht) +{ + switch(ht) + { + case HeuristicType::GEMM_Type: + { + os << "GEMM_Type"; + break; + } + case HeuristicType::GEMM_Config_Reshaped_Only_RHS: + { + os << "GEMM_Config_Reshaped_Only_RHS"; + break; + } + case HeuristicType::GEMM_Config_Reshaped: + { + os << "GEMM_Config_Reshaped"; + break; + } + default: + { + os << "Unknown"; + break; + } + } + return os; +} +std::ostream &operator<<(std::ostream &os, 
const DataType &dt) +{ + switch(dt) + { + case DataType::F32: + { + os << "F32"; + break; + } + case DataType::F16: + { + os << "F16"; + break; + } + case DataType::QASYMM8: + { + os << "QASYMM8"; + break; + } + default: + { + os << "Unknown"; + break; + } + } + return os; +} +std::ostream &operator<<(std::ostream &os, const HeuristicTree::Index &index) +{ + HeuristicType ht; + std::string ip; + DataType dt; + std::tie(ht, ip, dt) = index; + os << "Index("; + os << "HeuristicType=" << ht << ","; + os << "IP=" << ip << ","; + os << "DataType=" << dt; + os << ")"; + return os; +} + +namespace parser +{ +std::ostream &operator<<(std::ostream &os, CharPosition pos) +{ + os << "(Ln: " << pos.ln << ", Col: " << pos.col << ")"; + return os; +} +} // namespace parser + +} // namespace mlgo + +} // namespace arm_compute \ No newline at end of file diff --git a/src/runtime/CL/mlgo/Utils.h b/src/runtime/CL/mlgo/Utils.h new file mode 100644 index 0000000000..2e324dd439 --- /dev/null +++ b/src/runtime/CL/mlgo/Utils.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_MLGO_UTILS_H +#define SRC_MLGO_UTILS_H + +#include "src/runtime/CL/mlgo/Common.h" +#include "src/runtime/CL/mlgo/HeuristicTree.h" +#include "src/runtime/CL/mlgo/MLGOParser.h" + +#include <ostream> + +namespace arm_compute +{ +namespace mlgo +{ +std::ostream &operator<<(std::ostream &os, const GEMMConfigNative &config); +std::ostream &operator<<(std::ostream &os, const GEMMConfigReshapedOnlyRHS &config); +std::ostream &operator<<(std::ostream &os, const GEMMConfigReshaped &config); +std::ostream &operator<<(std::ostream &os, const HeuristicType &ht); +std::ostream &operator<<(std::ostream &os, const DataType &dt); +std::ostream &operator<<(std::ostream &os, const HeuristicTree::Index &index); +namespace parser +{ +std::ostream &operator<<(std::ostream &os, CharPosition pos); +} // namespace parser +} // namespace mlgo +} // namespace arm_compute + +#endif //SRC_MLGO_UTILS_H \ No newline at end of file diff --git a/tests/SConscript b/tests/SConscript index 92cf47b5c6..041ed8f548 100644 --- a/tests/SConscript +++ b/tests/SConscript @@ -284,4 +284,4 @@ if test_env['benchmark_examples']: Depends(arm_compute_benchmark_examples, arm_compute_test_framework) Depends(arm_compute_benchmark_examples, arm_compute_lib) Default(arm_compute_benchmark_examples) - Export('arm_compute_benchmark_examples') + Export('arm_compute_benchmark_examples') \ No newline at end of file diff --git a/tests/validation/CL/UNIT/MLGOHeuristics.cpp b/tests/validation/CL/UNIT/MLGOHeuristics.cpp new file mode 100644 index 0000000000..895b4b51d0 --- /dev/null +++ b/tests/validation/CL/UNIT/MLGOHeuristics.cpp @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/CL/mlgo/MLGOHeuristics.h" +#include "src/runtime/CL/mlgo/Utils.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Macros.h" + +using namespace arm_compute::mlgo; + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +TEST_SUITE(CL) +TEST_SUITE(UNIT) +TEST_SUITE(MLGOHeuristics) +TEST_CASE(CorrectDotMLGOShouldLoadCorrectly, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n] + 1, g71 , 8, f16, best-performance, static, gemm-config-reshaped-only-rhs, [m,n,k,n] + 2, g76 , 8, f16, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b , 0, var, m, ==, num, 10., 1, 2 + l , 1, gemm-type, reshaped + b , 2, var, r_mn, >=, num, 2., 3, 6 + b , 3, var, n, >=, num, 200., 4, 5 + l , 4, gemm-type, reshaped-only-rhs + l , 5, gemm-type, reshaped + l , 6, gemm-type, reshaped-only-rhs + + + b ,0,var, n, >, num, 100., 1, 4 + b ,1,var, r_mnk, <=, num, 20., 2, 3 + l ,2,gemm-config-reshaped-only-rhs, [4, 4,4,2,1,0,1] + l ,3,gemm-config-reshaped-only-rhs,[ 2, 2,4,2,1,1, 1 ] + b ,4,var, n, >=, num, 199.12, 5, 6 + l ,5,gemm-config-reshaped-only-rhs, [1, 4,3,4,0,0,0] + l ,6,gemm-config-reshaped-only-rhs, [5, 4,4,5,1,1,0] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + heuristics.reload_from_stream(ss); + + ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 10, 1024, 20, 1 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 201, 5, 1 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 200, 199, 16 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 199, 512, 4 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS); + + ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }), + framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == 
GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }), + framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 128, 101, 20, 1 }).second == GEMMConfigReshapedOnlyRHS{ 2, 2, 4, 2, true, true, true }), + framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }), + framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }), + framework::LogLevel::ERRORS); + + ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 100, 100, 20, 32 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }), + framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 128, 512, 1024, 1 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }), + framework::LogLevel::ERRORS); +} + +TEST_CASE(InvalidDotmlgoSyntaxShouldReturnInvalidStatus, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,pu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} + +TEST_SUITE(InvalidDotmlgoSemanticsShouldReturnInvalidStatus) +// If the semantics errors are local to some trees instead of the entire heuristics, an alternative is to simply +// ignore/remove those invalid trees. However the reason why we choose to throw, thus invalidating the entire +// heuristics is that if there are some invalid trees, the quality of the dotmlgo is called into question even if +// the rest of the trees are semantically valid, and they could severely degrade the performance of GEMM. Therefore +// this "all or nothing" approach when it comes to dotmlgo correctness is safer and more defensive. + +// Also note that the semantic error of the tree only refers to those that obstruct its evaluation and thus query, +// (e.g. invalid tree structure, unsupported features etc.) instead of those affecting the desired outcome +// (usually in terms of final GEMM performance, e.g. the effectiveness of the decision tree) + +// In the future we might want to check the content of the exceptions as well. But right now it suffices to only +// know that it throws exactly when it needs to. +TEST_CASE(MismatchesBetweenHeuristicsTableEntriesAndHeuristicTrees, framework::DatasetMode::ALL) +{ + { + // Mismatching number of entries 1 + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + // NOTE: This case might throw an internal error as the tree inserted by the heuristics-table cannot be checked + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); + } + + { + // Mismatching number of entries 2 + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); + } + + { + // Mismatching info + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); + } +} + +TEST_CASE(RepeatedHeuristicsTableEntriesId, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + 0, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} + +TEST_CASE(RepeatedHeuristicsTableEntriesIndex, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + 1, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} + +TEST_CASE(RepeatedHeuristicTreesId, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + 1, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + + l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_CASE(EmptyTree, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} + +TEST_CASE(InvalidTreeMissingRoot, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b ,2, var, m, ==, num, 10., 3, 4 + l ,3,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_CASE(InvalidTreeMissingNodes, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b ,0, var, m, ==, num, 10., 1, 2 + l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_CASE(InvalidTreeRepeatedNodeIds, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b ,0, var, m, ==, num, 10., 1, 2 + l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + l ,1,gemm-config-reshaped,[1,2,4,2,8,1,0,1,0] + l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_CASE(InvalidTreeDisjointNodes, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b ,0, var, m, ==, num, 10., 1, 2 + l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0] + + b ,4, var, n, ==, num, 10., 5, 6 + l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + l ,6,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0] + + l ,7,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_CASE(InvalidTreeLoop, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b ,0, var, m, ==, num, 10., 0, 1 + l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_CASE(InvalidTreeCycle, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b ,0, var, m, ==, num, 10., 1, 5 + b ,1, var, n, ==, num, 10., 2, 3 + l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + b ,3, var, k, ==, num, 10., 0, 4 + l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_CASE(InvalidTreeInvalidFeatures, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n] + + + b ,0, var, magic_feature, ==, num, 10., 1, 2 + l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0] + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); +} +TEST_SUITE_END() // InvalidDotmlgoSemanticsShouldReturnInvalidStatus + +TEST_CASE(InvalidUsageOfHeuristicsShouldReturnInvalidStatus, framework::DatasetMode::ALL) +{ + std::string mlgo_str = R"_( +
+ gemm-version, [1,2,1] + ip-type,gpu +
+ + 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n] + + + b , 0, var, m, ==, num, 10., 1, 2 + l , 1, gemm-type, reshaped + b , 2, var, r_mn, >=, num, 2., 3, 6 + b , 3, var, n, >=, num, 200., 4, 5 + l , 4, gemm-type, reshaped-only-rhs + l , 5, gemm-type, reshaped + l , 6, gemm-type, reshaped-only-rhs + + )_"; + std::stringstream ss(mlgo_str); + MLGOHeuristics heuristics; + ARM_COMPUTE_EXPECT(heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); + + // Querying unavailable heuristic type should return invalid Status + ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS); + // Querying unavailable ip target should return invalid Status + ARM_COMPUTE_EXPECT(!heuristics.query_gemm_type(Query{ "g77", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS); + // Querying unavailable data type should return invalid Status + ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g76", DataType::QASYMM8, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS); +} +TEST_SUITE_END() // MLGOHeuristics +TEST_SUITE_END() // UNIT +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute -- cgit v1.2.1