aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2021-01-08 15:16:02 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2021-02-02 22:05:34 +0000
commit7061eb283969f9a020c08349454447564e4dd5b3 (patch)
treecf11604aaf57ee86ff6e4980a3ddf95bb49167e8
parent74a142c11ec0b2f2b3fe1feb3fdfd98e9190762e (diff)
downloadComputeLibrary-7061eb283969f9a020c08349454447564e4dd5b3.tar.gz
Implement MLGO module
* Implement MLGOHeuristics which provides a query and a loading interface * Implement a top-down parser MLGOParser for parsing dotmlgo * Add validation tests for MLGOHeuristics Resolves COMPMID-3840, COMPMID-3841 Signed-off-by: SiCong Li <sicong.li@arm.com> Change-Id: Iae96d2779524b2dd83623d1a3a30ef57823ae084 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4941 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--Android.bp4
-rw-r--r--SConscript1
-rw-r--r--src/runtime/CL/mlgo/Common.h81
-rw-r--r--src/runtime/CL/mlgo/HeuristicTree.cpp253
-rw-r--r--src/runtime/CL/mlgo/HeuristicTree.h198
-rw-r--r--src/runtime/CL/mlgo/MLGOHeuristics.cpp242
-rw-r--r--src/runtime/CL/mlgo/MLGOHeuristics.h139
-rw-r--r--src/runtime/CL/mlgo/MLGOParser.cpp812
-rw-r--r--src/runtime/CL/mlgo/MLGOParser.h199
-rw-r--r--src/runtime/CL/mlgo/Utils.cpp143
-rw-r--r--src/runtime/CL/mlgo/Utils.h50
-rw-r--r--tests/SConscript2
-rw-r--r--tests/validation/CL/UNIT/MLGOHeuristics.cpp461
13 files changed, 2584 insertions, 1 deletions
diff --git a/Android.bp b/Android.bp
index 663adaedfc..f9761e2352 100644
--- a/Android.bp
+++ b/Android.bp
@@ -611,6 +611,10 @@ cc_library_static {
"src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp",
"src/runtime/CL/gemm/CLGEMMDefaultTypeMidgard.cpp",
"src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp",
+ "src/runtime/CL/mlgo/HeuristicTree.cpp",
+ "src/runtime/CL/mlgo/MLGOHeuristics.cpp",
+ "src/runtime/CL/mlgo/MLGOParser.cpp",
+ "src/runtime/CL/mlgo/Utils.cpp",
"src/runtime/CL/tuners/BifrostTuner.cpp",
"src/runtime/CL/tuners/CLTuningParametersList.cpp",
"src/runtime/CL/tuners/MidgardTuner.cpp",
diff --git a/SConscript b/SConscript
index 7e09240768..307a0983d8 100644
--- a/SConscript
+++ b/SConscript
@@ -226,6 +226,7 @@ if env['opencl']:
runtime_files += Glob('src/runtime/CL/tuners/*.cpp')
runtime_files += Glob('src/runtime/gpu/cl/*.cpp')
runtime_files += Glob('src/runtime/gpu/cl/operators/*.cpp')
+ runtime_files += Glob('src/runtime/CL/mlgo/*.cpp')
graph_files += Glob('src/graph/backends/CL/*.cpp')
diff --git a/src/runtime/CL/mlgo/Common.h b/src/runtime/CL/mlgo/Common.h
new file mode 100644
index 0000000000..a2d3ec8241
--- /dev/null
+++ b/src/runtime/CL/mlgo/Common.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_MLGO_COMMON_H
+#define SRC_MLGO_COMMON_H
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/CL/CLTypes.h"
+
+namespace arm_compute
+{
+namespace mlgo
+{
+/** Types of Heuristic (tree)
+ *
+ * Each heuristic tree is tagged with one of these types; the type is the first
+ * component of the index under which a tree is stored and looked up.
+ */
+enum class HeuristicType
+{
+ GEMM_Type, /**< About the type of gemm */
+ GEMM_Config_Native, /**< About the gemm config for native kernel */
+ GEMM_Config_Reshaped_Only_RHS, /**< About the gemm config for reshaped only rhs kernel */
+ GEMM_Config_Reshaped /**< About the gemm config for reshaped kernel */
+};
+
+/** GEMM type, expressed with the CL GEMM kernel type enumeration */
+using GEMMType = CLGEMMKernelType;
+
+/** GEMM Configuration for Native kernel
+ *
+ * Plain aggregate of block sizes; members have no default initialisers, so
+ * value-initialise (e.g. GEMMConfigNative{}) when a zeroed configuration is needed.
+ */
+struct GEMMConfigNative
+{
+ unsigned int m0; /**< Number of rows processed by the matrix multiplication */
+ unsigned int n0; /**< Number of columns processed by the matrix multiplication */
+ unsigned int k0; /**< Number of partial accumulations performed by the matrix multiplication */
+};
+
+/** GEMM Configuration for Reshaped Only RHS kernel
+ *
+ * Plain aggregate; members have no default initialisers, so value-initialise
+ * (e.g. GEMMConfigReshapedOnlyRHS{}) when a zeroed configuration is needed.
+ */
+struct GEMMConfigReshapedOnlyRHS
+{
+ unsigned int m0; /**< Number of rows processed by the matrix multiplication */
+ unsigned int n0; /**< Number of columns processed by the matrix multiplication */
+ unsigned int k0; /**< Number of partial accumulations performed by the matrix multiplication */
+ unsigned int h0; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
+ bool interleave_rhs; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */
+ bool transpose_rhs; /**< True if the (k0xn0) block has to be transposed before being stored */
+ bool export_cl_image; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */
+};
+
+/** GEMM Configuration for Reshaped kernel
+ *
+ * Plain aggregate; members have no default initialisers, so value-initialise
+ * (e.g. GEMMConfigReshaped{}) when a zeroed configuration is needed.
+ */
+struct GEMMConfigReshaped
+{
+ unsigned int m0; /**< Number of rows processed by the matrix multiplication */
+ unsigned int n0; /**< Number of columns processed by the matrix multiplication */
+ unsigned int k0; /**< Number of partial accumulations performed by the matrix multiplication */
+ unsigned int v0; /**< Number of vertical blocks of size (m0xk0) stored on the same output row */
+ unsigned int h0; /**< Number of horizontal blocks of size (k0xn0) stored on the same output row */
+ bool interleave_lhs; /**< True if the v0 (m0xk0) blocks have to be interleaved in the output row */
+ bool interleave_rhs; /**< True if the h0 (k0xn0) blocks have to be interleaved in the output row */
+ bool transpose_rhs; /**< True if the (k0xn0) block has to be transposed before being stored */
+ bool export_cl_image; /**< True if the reshaped rhs has to be exported to cl_image. n0 must be equal to 4 */
+};
+
+} // namespace mlgo
+} // namespace arm_compute
+#endif // SRC_MLGO_COMMON_H \ No newline at end of file
diff --git a/src/runtime/CL/mlgo/HeuristicTree.cpp b/src/runtime/CL/mlgo/HeuristicTree.cpp
new file mode 100644
index 0000000000..65219998cb
--- /dev/null
+++ b/src/runtime/CL/mlgo/HeuristicTree.cpp
@@ -0,0 +1,253 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/CL/mlgo/HeuristicTree.h"
+#include "arm_compute/core/Log.h"
+
+#include <algorithm>
+#include <deque>
+#include <set>
+namespace arm_compute
+{
+namespace mlgo
+{
+namespace
+{
+/** Evaluate a branch condition ("feature op threshold") against a GEMM shape
+ *
+ * The feature named by the condition is looked up in a table of values derived
+ * from the shape and then compared with the condition's threshold.
+ *
+ * @note Precondition: the feature name and the ConditionalOp are valid
+ *
+ * @param shape GEMM shape supplying m, n, k and b
+ * @param cond  Condition to evaluate
+ * @return bool Outcome of the comparison. EQ (and any op not listed explicitly) compares within a tolerance of 0.0001
+ */
+bool evaluate(GEMMShape shape, Condition cond)
+{
+    using FeatureValue = std::pair<std::string, float>;
+
+    constexpr float tolerance = 0.0001f;
+
+    const float m = static_cast<float>(shape.m);
+    const float n = static_cast<float>(shape.n);
+    const float k = static_cast<float>(shape.k);
+    const float b = static_cast<float>(shape.b);
+
+    // Primary features first, followed by the secondary (derived) ones
+    const std::vector<FeatureValue> feature_table
+    {
+        { "m", m },
+        { "n", n },
+        { "k", k },
+        { "b", b },
+        { "r_mn", m / n },
+        { "r_mk", m / k },
+        { "r_nk", n / k },
+        { "r_mnk", m / (n / k) },
+        { "workload", (m * n * b) / 20.0 }
+    };
+
+    const auto match = std::find_if(feature_table.cbegin(), feature_table.cend(), [&cond](const FeatureValue & entry)
+    {
+        return entry.first == cond.feature;
+    });
+    ARM_COMPUTE_ERROR_ON(match == feature_table.cend());
+
+    const float lhs = match->second;
+    const float rhs = cond.threshold;
+    switch(cond.op)
+    {
+        case ConditionalOp::LT:
+            return lhs < rhs;
+        case ConditionalOp::LE:
+            return lhs <= rhs;
+        case ConditionalOp::GT:
+            return lhs > rhs;
+        case ConditionalOp::GE:
+            return lhs >= rhs;
+        case ConditionalOp::EQ:
+        default:
+            // Floating point equality is tested within a small tolerance
+            return std::abs(lhs - rhs) < tolerance;
+    }
+}
+
+} // namespace
+
+// Out-of-class definitions of HeuristicTree's static constexpr data members.
+// NOTE(review): presumably needed because the members are odr-used and the code base targets pre-C++17 - confirm
+constexpr size_t HeuristicTree::_max_num_nodes;
+constexpr size_t HeuristicTree::_max_query_depth;
+constexpr HeuristicTree::NodeID HeuristicTree::_root;
+
+/** Default constructor
+ *
+ * Delegates to the main constructor with tree id 0, HeuristicType::GEMM_Type,
+ * an empty IP target and DataType::F32. The resulting tree has no nodes.
+ */
+HeuristicTree::HeuristicTree()
+ : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
+{
+}
+
+/** Construct an empty tree (no nodes) tagged with the given id and index components
+ *
+ * @param id        Tree ID
+ * @param h_type    Heuristic type of the tree
+ * @param ip_target IP target associated with the tree
+ * @param data_type Data type associated with the tree
+ */
+HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
+ : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
+{
+}
+
+/** Walk the tree from the root and return the value of the leaf that is reached
+ *
+ * At every branch node the node's condition is evaluated against @p shape and the
+ * walk continues with the true or the false child accordingly.
+ *
+ * @note Nodes are fetched with std::map::at, so a missing root or child raises
+ *       std::out_of_range. Validate the tree with check() before querying.
+ *
+ * @tparam T Leaf value type. Must match the type the leaves were added with
+ *
+ * @param shape A @ref GEMMShape for the query
+ *
+ * @return std::pair<bool, T> The bool is false (and the value is T{}) if the maximum query depth is exceeded
+ *         or if a node cannot be downcast to the expected node type, e.g. when T does not match the stored leaf type
+ */
+template <typename T>
+std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
+{
+    // Root ID = 0;
+    auto   cur_node = _tree.at(_root).get();
+    size_t depth    = 0;
+    while(cur_node->type() != NodeType::Leaf)
+    {
+        if(depth > _max_query_depth)
+        {
+            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
+            return std::make_pair(false, T{});
+        }
+        ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
+        auto br_node = dynamic_cast<BranchNode *>(cur_node);
+        if(br_node == nullptr)
+        {
+            // type() and the dynamic type disagree; fail the query instead of dereferencing a null pointer
+            ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot cast to branch node");
+            return std::make_pair(false, T{});
+        }
+        if(evaluate(shape, br_node->condition))
+        {
+            cur_node = _tree.at(br_node->true_node).get();
+        }
+        else
+        {
+            cur_node = _tree.at(br_node->false_node).get();
+        }
+        ++depth;
+    }
+    ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
+    auto l_node = dynamic_cast<LeafNode<T> *>(cur_node);
+    if(l_node == nullptr)
+    {
+        // The leaf holds a value of a type other than T; report failure rather than dereferencing a null pointer
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot cast to leaf node. Possibly mismatched leaf value type");
+        return std::make_pair(false, T{});
+    }
+    return std::make_pair(true, l_node->value);
+}
+
+/** Insert a leaf node holding @p val under node id @p id
+ *
+ * @tparam T Leaf value type
+ *
+ * @param id  Leaf node ID; must not already be present in the tree
+ * @param val Leaf node value
+ *
+ * @return bool False if the tree already holds the maximum number of nodes or if @p id is taken; true otherwise
+ */
+template <typename T>
+bool HeuristicTree::add_leaf(NodeID id, T val)
+{
+    const bool tree_is_full = !(_tree.size() < _max_num_nodes);
+    if(tree_is_full)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
+        return false;
+    }
+
+    const bool id_taken = _tree.count(id) != 0;
+    if(id_taken)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
+        return false;
+    }
+
+    _tree[id] = std::make_unique<LeafNode<T>>(id, val);
+    return true;
+}
+
+/** Insert a branch node under node id @p id
+ *
+ * The feature name of @p cond is matched case-insensitively against the supported
+ * features; the node stores the condition with the feature name lower-cased.
+ * The child ids are not validated here; structural validation is done by check().
+ *
+ * @param id     Branch node ID; must not already be present in the tree
+ * @param cond   Branch node @ref Condition
+ * @param t_node ID of the node to follow when the condition holds
+ * @param f_node ID of the node to follow when the condition does not hold
+ *
+ * @return bool False if the tree already holds the maximum number of nodes, if the feature is not supported,
+ *         or if @p id is taken; true otherwise
+ */
+bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
+{
+    if(_tree.size() >= _max_num_nodes)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
+        return false;
+    }
+
+    const std::set<std::string> supported_features =
+    {
+        "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
+    };
+    // Keep the original spelling for the diagnostic message
+    const auto orig_feature = cond.feature;
+    // std::tolower requires a value representable as unsigned char (or EOF); passing a negative plain char is
+    // undefined behaviour, so route every character through unsigned char and convert the int result back explicitly
+    std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](unsigned char c)
+    {
+        return static_cast<char>(std::tolower(c));
+    });
+    if(supported_features.find(cond.feature) == supported_features.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
+        return false;
+    }
+
+    if(_tree.find(id) != _tree.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
+        return false;
+    }
+    _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
+    return true;
+}
+
+/** Breadth-first structural validation starting from the root
+ *
+ * Fails if a referenced node is missing, if any node is reached more than once,
+ * or if some stored node is never reached from the root.
+ *
+ * @return bool True if none of the above problems is found
+ */
+bool HeuristicTree::check_if_structurally_correct() const
+{
+    std::set<NodeID>   seen;
+    std::deque<NodeID> pending{ _root };
+
+    while(!pending.empty())
+    {
+        const NodeID node_id = pending.front();
+        pending.pop_front();
+
+        const auto entry = _tree.find(node_id);
+        if(entry == _tree.end())
+        {
+            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", node_id);
+            return false;
+        }
+
+        // insert() reports through .second whether the id was newly added
+        const bool first_visit = seen.insert(node_id).second;
+        if(!first_visit)
+        {
+            ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
+            return false;
+        }
+
+        Node *const node = entry->second.get();
+        if(node->type() == NodeType::Branch)
+        {
+            const auto branch = dynamic_cast<BranchNode *>(node);
+            pending.push_back(branch->true_node);
+            pending.push_back(branch->false_node);
+        }
+    }
+
+    // Every stored node must have been reached from the root
+    if(seen.size() != _tree.size())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
+        return false;
+    }
+    return true;
+}
+
+/** Validate the tree: it must be non-empty, contain the root node and be structurally correct
+ *
+ * @return bool True if all checks pass
+ */
+bool HeuristicTree::check()
+{
+    if(_tree.empty())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
+        return false;
+    }
+
+    const bool has_root = _tree.count(_root) != 0;
+    if(!has_root)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
+        return false;
+    }
+
+    return check_if_structurally_correct();
+}
+
+// query() and add_leaf() are defined in this translation unit, so they are explicitly
+// instantiated below for each leaf value type: GEMMType, GEMMConfigNative,
+// GEMMConfigReshapedOnlyRHS and GEMMConfigReshaped.
+
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
+/** Explicit template instantiation @relates HeuristicTree */
+template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
+
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val);
+/** Explicit template instantiation @relates HeuristicTree */
+template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
+
+} // namespace mlgo
+
+} // namespace arm_compute
diff --git a/src/runtime/CL/mlgo/HeuristicTree.h b/src/runtime/CL/mlgo/HeuristicTree.h
new file mode 100644
index 0000000000..64d79ffaa1
--- /dev/null
+++ b/src/runtime/CL/mlgo/HeuristicTree.h
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_MLGO_HEURISTICTREE_H
+#define SRC_MLGO_HEURISTICTREE_H
+
+#include "arm_compute/core/Types.h"
+#include "src/runtime/CL/mlgo/Common.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+namespace arm_compute
+{
+namespace mlgo
+{
+/** Conditional ops: the comparison operators usable in a branch @ref Condition */
+enum class ConditionalOp
+{
+ EQ, /**< Equal */
+ LT, /**< Less than */
+ LE, /**< Less than or equal to */
+ GT, /**< Greater than */
+ GE, /**< Greater than or equal to */
+};
+
+/** A branch condition expression evaluating: feature op threshold */
+struct Condition
+{
+ std::string feature; /**< Feature name */
+ ConditionalOp op; /**< Conditional op */
+ float threshold; /**< Threshold value */
+};
+
+/** GEMM Shape used for query
+ *
+ * Plain aggregate; the members are listed in the order m, n, k, b, which is the
+ * order used when brace-initialising a shape.
+ */
+struct GEMMShape
+{
+ unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
+ unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
+ unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
+ unsigned int b; /**< Batch size */
+};
+
+/** A binary decision tree based heuristic
+ *
+ * Nodes are stored in a map keyed by node id. Branch nodes carry a @ref Condition and the
+ * ids of their true / false children; leaf nodes carry the value returned by a query.
+ * The tree is move-only.
+ */
+class HeuristicTree
+{
+public:
+ using NodeID = size_t; /**< Identifier of a node within a tree */
+ using TreeID = size_t; /**< Identifier of a tree */
+ using Index = std::tuple<HeuristicType, std::string, DataType>; /**< (heuristic type, ip target, data type), see index() */
+ /** Kind of node */
+ enum class NodeType
+ {
+ Branch,
+ Leaf
+ };
+ /** Polymorphic base of all nodes */
+ struct Node
+ {
+ /** @return The kind of this node */
+ virtual NodeType type() const = 0;
+ virtual ~Node() = default;
+ };
+
+ /** Node holding a condition and the ids of the two children to choose from */
+ struct BranchNode : public Node
+ {
+ BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
+ : id{ id }, condition{ cond }, true_node{ t_node }, false_node{ f_node }
+ {
+ }
+ NodeType type() const override
+ {
+ return NodeType::Branch;
+ }
+ NodeID id; /**< ID of this node */
+ Condition condition; /**< Condition evaluated at this node */
+ NodeID true_node; /**< ID of the child followed when the condition holds */
+ NodeID false_node; /**< ID of the child followed when the condition does not hold */
+ };
+
+ /** Node holding a value of type T */
+ template <typename T>
+ struct LeafNode : public Node
+ {
+ LeafNode(NodeID id, T val)
+ : id{ id }, value{ val }
+ {
+ }
+ NodeType type() const override
+ {
+ return NodeType::Leaf;
+ }
+ NodeID id; /**< ID of this node */
+ T value; /**< Value stored in this leaf */
+ };
+
+public:
+ /** Constructor */
+ HeuristicTree();
+ /** Constructor */
+ HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type);
+ // Since the HeuristicTree is a handle that owns the nodes, it is move-only
+ /** Prevent copy construction */
+ HeuristicTree(const HeuristicTree &) = delete;
+ /** Prevent copy assignment */
+ HeuristicTree &operator=(const HeuristicTree &) = delete;
+ /** Move constructor */
+ HeuristicTree(HeuristicTree &&other) noexcept = default;
+ /** Move assignment */
+ HeuristicTree &operator=(HeuristicTree &&other) noexcept = default;
+
+ /** Query a leaf value given a gemm shape
+ *
+ * @tparam T Leaf value type
+ * @param shape A @ref GEMMShape for the query
+ * @return std::pair<bool, T> Outcome contains bool, signalling if the query succeeded or not
+ */
+ template <typename T>
+ std::pair<bool, T> query(GEMMShape shape) const;
+
+ /** Add a leaf node
+ *
+ * @tparam T Leaf value type
+ * @param id Leaf node ID
+ * @param leaf_value Leaf node value
+ * @return bool If the addition succeeded or not
+ */
+ template <typename T>
+ bool add_leaf(NodeID id, T leaf_value);
+ /** Add a branch node
+ *
+ * @param id Branch node ID
+ * @param cond Branch node @ref Condition
+ * @param true_node True node's ID
+ * @param false_node False node's ID
+ * @return bool If the addition succeeded or not
+ */
+ bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node);
+
+ /** Get tree ID
+ * @return TreeID
+ */
+ TreeID id() const
+ {
+ return _id;
+ }
+
+ /** Get tree index
+ * @return Index
+ */
+ Index index() const
+ {
+ return std::make_tuple(_heuristic_type, _ip_target, _data_type);
+ }
+
+ /** Check if tree is valid
+ * @return bool
+ */
+ bool check();
+
+private:
+ static constexpr size_t _max_query_depth{ 1000 }; // Maximum depth of query
+ static constexpr size_t _max_num_nodes{ 100000 }; // Maximum number of nodes contained by the tree
+ static constexpr NodeID _root{ 0 }; // Root tree ID
+
+private:
+ /** Check the node structure reachable from the root
+ * @return bool
+ */
+ bool check_if_structurally_correct() const;
+
+private:
+ TreeID _id; /**< Heuristic tree ID */
+ HeuristicType _heuristic_type; /**< Heuristic type */
+ std::string _ip_target; /**< IP target associated with the tree */
+ DataType _data_type; /**< Data type associated with the tree */
+ std::map<NodeID, std::unique_ptr<Node>> _tree; /**< Tree representation */
+};
+} // namespace mlgo
+
+} // namespace arm_compute
+
+#endif //SRC_MLGO_HEURISTICTREE_H \ No newline at end of file
diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.cpp b/src/runtime/CL/mlgo/MLGOHeuristics.cpp
new file mode 100644
index 0000000000..ec99b50488
--- /dev/null
+++ b/src/runtime/CL/mlgo/MLGOHeuristics.cpp
@@ -0,0 +1,242 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
+#include "arm_compute/core/Log.h"
+#include "src/runtime/CL/mlgo/MLGOParser.h"
+
+#include <fstream>
+
+namespace arm_compute
+{
+namespace mlgo
+{
+/** Member-wise equality of two native gemm configurations */
+bool operator==(const GEMMConfigNative &lhs, const GEMMConfigNative &rhs)
+{
+    return lhs.m0 == rhs.m0
+           && lhs.n0 == rhs.n0
+           && lhs.k0 == rhs.k0;
+}
+/** Member-wise equality of two reshaped-only-rhs gemm configurations */
+bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs)
+{
+    return lhs.m0 == rhs.m0
+           && lhs.n0 == rhs.n0
+           && lhs.k0 == rhs.k0
+           && lhs.h0 == rhs.h0
+           && lhs.interleave_rhs == rhs.interleave_rhs
+           && lhs.transpose_rhs == rhs.transpose_rhs
+           && lhs.export_cl_image == rhs.export_cl_image;
+}
+/** Member-wise equality of two reshaped gemm configurations */
+bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs)
+{
+    return lhs.m0 == rhs.m0
+           && lhs.n0 == rhs.n0
+           && lhs.k0 == rhs.k0
+           && lhs.v0 == rhs.v0
+           && lhs.h0 == rhs.h0
+           && lhs.interleave_lhs == rhs.interleave_lhs
+           && lhs.interleave_rhs == rhs.interleave_rhs
+           && lhs.transpose_rhs == rhs.transpose_rhs
+           && lhs.export_cl_image == rhs.export_cl_image;
+}
+
+// Out-of-class definition of the static constexpr data member _max_num_trees
+constexpr size_t MLGOHeuristics::_max_num_trees;
+
+/** Default constructor: no tree indices, no trees, no cached tree validities; the object starts as not valid */
+MLGOHeuristics::MLGOHeuristics()
+ : _indices{}, _trees{}, _tree_valid{}, _valid{ false }
+{
+}
+
+/** Query the gemm type for the given target, data type and shape
+ *
+ * @param query Query carrying ip target, data type and the m, n, k, b shape
+ *
+ * @return std::pair<bool, GEMMType> The bool is false (paired with GEMMType::RESHAPED as a placeholder) when the
+ *         heuristics are not valid or no matching tree exists; otherwise the outcome of the tree query
+ */
+std::pair<bool, GEMMType> MLGOHeuristics::query_gemm_type(Query query) const
+{
+    ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm type");
+    const std::pair<bool, GEMMType> failure{ false, GEMMType::RESHAPED };
+
+    if(!_valid)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
+        return failure;
+    }
+
+    const auto tree_it = _trees.find(std::make_tuple(HeuristicType::GEMM_Type, query.ip_target, query.data_type));
+    if(tree_it == _trees.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
+        return failure;
+    }
+
+    const GEMMShape shape_query{ query.m, query.n, query.k, query.b };
+    return tree_it->second.query<GEMMType>(shape_query);
+}
+/** Query the native gemm configuration for the given target, data type and shape
+ *
+ * @param query Query carrying ip target, data type and the m, n, k, b shape
+ *
+ * @return std::pair<bool, GEMMConfigNative> The bool is false (paired with a value-initialised config) when the
+ *         heuristics are not valid or no matching tree exists; otherwise the outcome of the tree query
+ */
+std::pair<bool, GEMMConfigNative> MLGOHeuristics::query_gemm_config_native(Query query) const
+{
+    ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config native");
+    const std::pair<bool, GEMMConfigNative> failure{ false, GEMMConfigNative{} };
+
+    if(!_valid)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
+        return failure;
+    }
+
+    const auto tree_it = _trees.find(std::make_tuple(HeuristicType::GEMM_Config_Native, query.ip_target, query.data_type));
+    if(tree_it == _trees.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
+        return failure;
+    }
+
+    const GEMMShape shape_query{ query.m, query.n, query.k, query.b };
+    return tree_it->second.query<GEMMConfigNative>(shape_query);
+}
+/** Query the reshaped-only-rhs gemm configuration for the given target, data type and shape
+ *
+ * @param query Query carrying ip target, data type and the m, n, k, b shape
+ *
+ * @return std::pair<bool, GEMMConfigReshapedOnlyRHS> The bool is false (paired with a value-initialised config) when
+ *         the heuristics are not valid or no matching tree exists; otherwise the outcome of the tree query
+ */
+std::pair<bool, GEMMConfigReshapedOnlyRHS> MLGOHeuristics::query_gemm_config_reshaped_only_rhs(Query query) const
+{
+    ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config reshaped only rhs");
+    const std::pair<bool, GEMMConfigReshapedOnlyRHS> failure{ false, GEMMConfigReshapedOnlyRHS{} };
+
+    if(!_valid)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
+        return failure;
+    }
+
+    const auto tree_it = _trees.find(std::make_tuple(HeuristicType::GEMM_Config_Reshaped_Only_RHS, query.ip_target, query.data_type));
+    if(tree_it == _trees.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
+        return failure;
+    }
+
+    const GEMMShape shape_query{ query.m, query.n, query.k, query.b };
+    return tree_it->second.query<GEMMConfigReshapedOnlyRHS>(shape_query);
+}
+/** Query the reshaped gemm configuration for the given target, data type and shape
+ *
+ * @param query Query carrying ip target, data type and the m, n, k, b shape
+ *
+ * @return std::pair<bool, GEMMConfigReshaped> The bool is false (paired with a value-initialised config) when the
+ *         heuristics are not valid or no matching tree exists; otherwise the outcome of the tree query
+ */
+std::pair<bool, GEMMConfigReshaped> MLGOHeuristics::query_gemm_config_reshaped(Query query) const
+{
+    ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics querying gemm config reshaped");
+    const std::pair<bool, GEMMConfigReshaped> failure{ false, GEMMConfigReshaped{} };
+
+    if(!_valid)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Invalid DotMLGO. Use default heuristics instead");
+        return failure;
+    }
+
+    const auto tree_it = _trees.find(std::make_tuple(HeuristicType::GEMM_Config_Reshaped, query.ip_target, query.data_type));
+    if(tree_it == _trees.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
+        return failure;
+    }
+
+    const GEMMShape shape_query{ query.m, query.n, query.k, query.b };
+    return tree_it->second.query<GEMMConfigReshaped>(shape_query);
+}
+
+/** Validate the tree with the given id and cache a positive outcome
+ *
+ * @param id ID of the tree to check
+ *
+ * @return bool False if the tree cannot be found or fails its own check(); true otherwise,
+ *         in which case the tree is recorded as valid
+ */
+bool MLGOHeuristics::check_heuristic_tree(HeuristicTree::TreeID id)
+{
+    const std::pair<bool, HeuristicTree *> lookup = get_heuristic_tree(id);
+    if(!lookup.first)
+    {
+        return false;
+    }
+    if(!lookup.second->check())
+    {
+        return false;
+    }
+    _tree_valid[id] = true;
+    return true;
+}
+
+/** Top level check: every registered tree must have been marked valid by check_heuristic_tree
+ *
+ * @return bool True if no registered tree is still marked as not valid
+ */
+bool MLGOHeuristics::check_all() const
+{
+    // Tree validities are already checked and cached.
+    const bool every_tree_checked = std::all_of(_tree_valid.begin(), _tree_valid.end(), [](const auto & entry)
+    {
+        return entry.second;
+    });
+    if(every_tree_checked)
+    {
+        // Other top level checks...
+
+        return true;
+    }
+
+    ARM_COMPUTE_LOG_INFO_MSG_CORE("Missing checks on some trees. Make sure to call check_heuristic_tree after each tree is completed. This could also indicate there are no trees in the dotmlgo");
+    return false;
+}
+
+/** Look up a tree by its id
+ *
+ * The id is first translated to the tree's index, which is then used to locate the tree.
+ *
+ * @param id ID of the tree
+ *
+ * @return std::pair<bool, HeuristicTree *> (true, pointer to the stored tree) on success;
+ *         (false, nullptr) if either the id or the corresponding index is unknown
+ */
+std::pair<bool, HeuristicTree *> MLGOHeuristics::get_heuristic_tree(HeuristicTree::TreeID id)
+{
+    const auto index_it = _indices.find(id);
+    if(index_it == _indices.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot find tree with id %zu", id);
+        return { false, nullptr };
+    }
+
+    const auto tree_it = _trees.find(index_it->second);
+    if(tree_it == _trees.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot find tree index");
+        return { false, nullptr };
+    }
+
+    HeuristicTree *const tree = &tree_it->second;
+    return { true, tree };
+}
+
+/** Take ownership of a tree and register it under its id and index
+ *
+ * The tree is registered as not yet validated.
+ *
+ * @note Precondition: correctness of @p t is guaranteed by the tree construction process
+ *
+ * @param t Tree to add (moved from on success)
+ *
+ * @return bool False if the maximum number of trees is reached, or if the tree's id or index is already
+ *         registered; true otherwise
+ */
+bool MLGOHeuristics::add_heuristic_tree(HeuristicTree &&t)
+{
+    const bool at_capacity = !(_indices.size() < _max_num_trees);
+    if(at_capacity)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the max number of trees allowed: %zu", _max_num_trees);
+        return false;
+    }
+
+    // The tree id must not be registered yet
+    const auto tree_id = t.id();
+    if(_indices.find(tree_id) != _indices.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add redundant trees; tree id %zu already exists", tree_id);
+        return false;
+    }
+
+    // Neither must the tree index
+    const auto tree_index = t.index();
+    if(_trees.find(tree_index) != _trees.end())
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("Cannot add redundant trees; tree index already exists");
+        return false;
+    }
+
+    _indices[tree_id]    = tree_index;
+    _trees[tree_index]   = std::move(t);
+    _tree_valid[tree_id] = false;
+    return true;
+}
+
+/** Reload the heuristics from a DotMLGO file
+ *
+ * @param filename Path of the DotMLGO file
+ *
+ * @return bool False (and the heuristics are marked as not valid) if the file cannot be opened;
+ *         otherwise the outcome of reload_from_stream
+ */
+bool MLGOHeuristics::reload_from_file(const std::string &filename)
+{
+    std::ifstream dotmlgo_file;
+    dotmlgo_file.exceptions(std::ifstream::badbit);
+    dotmlgo_file.open(filename, std::ios::in);
+
+    if(dotmlgo_file.is_open())
+    {
+        return reload_from_stream(dotmlgo_file);
+    }
+
+    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot open DotMLGO file %s. Use default heuristics instead", filename.c_str());
+    _valid = false;
+    return false;
+}
+
+/** Reload the heuristics by parsing DotMLGO content from a stream
+ *
+ * On success this object is replaced by the parsed heuristics and marked valid;
+ * on failure it is only marked as not valid.
+ *
+ * @param in Stream to parse
+ *
+ * @return bool True if parsing succeeded
+ */
+bool MLGOHeuristics::reload_from_stream(std::istream &in)
+{
+    auto parse_result = parser::parse_mlgo(in);
+    if(parse_result.first)
+    {
+        *this = std::move(parse_result.second);
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO loaded successfully");
+        _valid = true;
+        return true;
+    }
+
+    ARM_COMPUTE_LOG_INFO_MSG_CORE("DotMLGO parsing failed. Use default heuristics instead");
+    _valid = false;
+    return false;
+}
+
+} // namespace mlgo
+} // namespace arm_compute \ No newline at end of file
diff --git a/src/runtime/CL/mlgo/MLGOHeuristics.h b/src/runtime/CL/mlgo/MLGOHeuristics.h
new file mode 100644
index 0000000000..02e8111b6e
--- /dev/null
+++ b/src/runtime/CL/mlgo/MLGOHeuristics.h
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_RUNTIME_MLGO_MLGOHEURISTICS_H
+#define SRC_RUNTIME_MLGO_MLGOHEURISTICS_H
+
+#include "src/runtime/CL/mlgo/Common.h"
+#include "src/runtime/CL/mlgo/HeuristicTree.h"
+
+#include <iostream>
+#include <map>
+#include <string>
+#include <utility>
+namespace arm_compute
+{
+namespace mlgo
+{
+/** Query interface
+ *
+ * Bundles the IP target name, the data type and the GEMM problem dimensions that are passed to the
+ * query methods of @ref MLGOHeuristics.
+ */
+struct Query
+{
+ std::string ip_target; /**< The name of the IP target */
+ DataType data_type; /**< Data type */
+ unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
+ unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
+ unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
+ unsigned int b; /**< Batch size */
+};
+
+/** Equality comparison of two @ref GEMMConfigReshapedOnlyRHS objects (declaration only; defined outside this header) */
+bool operator==(const GEMMConfigReshapedOnlyRHS &lhs, const GEMMConfigReshapedOnlyRHS &rhs);
+/** Equality comparison of two @ref GEMMConfigReshaped objects (declaration only; defined outside this header) */
+bool operator==(const GEMMConfigReshaped &lhs, const GEMMConfigReshaped &rhs);
+
+/** MLGOHeuristics for configuring GEMM kernels */
+class MLGOHeuristics
+{
+public:
+    /** Constructor */
+    MLGOHeuristics();
+    /** Query the gemm type
+     *
+     * @param[in] query Query
+     *
+     * @return std::pair<bool, GEMMType> bool signals if the query succeeded or failed
+     */
+    std::pair<bool, GEMMType> query_gemm_type(Query query) const;
+    /** Query the gemm configuration for native kernel
+     *
+     * @param[in] query Query
+     *
+     * @return std::pair<bool, GEMMConfigNative> bool signals if the query succeeded or failed
+     */
+    std::pair<bool, GEMMConfigNative> query_gemm_config_native(Query query) const;
+    /** Query the gemm configuration for reshaped only rhs kernel
+     *
+     * @param[in] query Query
+     *
+     * @return std::pair<bool, GEMMConfigReshapedOnlyRHS> bool signals if the query succeeded or failed
+     */
+    std::pair<bool, GEMMConfigReshapedOnlyRHS> query_gemm_config_reshaped_only_rhs(Query query) const;
+    /** Query the gemm configuration for reshaped kernel
+     *
+     * @param[in] query Query
+     *
+     * @return std::pair<bool, GEMMConfigReshaped> bool signals if the query succeeded or failed
+     */
+    std::pair<bool, GEMMConfigReshaped> query_gemm_config_reshaped(Query query) const;
+    /** (Re)Load the heuristics from reading a dotmlgo file
+     *
+     * @param[in] filename Path to the dotmlgo file
+     *
+     * @return bool Signals if the reload succeeded or failed
+     */
+    bool reload_from_file(const std::string &filename);
+    /** (Re)Load the heuristics from reading an input stream
+     *
+     * @param[in] istream Istream containing mlgo heuristics
+     *
+     * @return bool Signals if the reload succeeded or failed
+     */
+    bool reload_from_stream(std::istream &istream);
+
+    /** Get the heuristic tree from tree id
+     *
+     * @param[in] id Tree id.
+     *
+     * @return std::pair<bool, HeuristicTree *> bool signals if the lookup succeeded or failed. The pointer must
+     *         only be used when the bool is true
+     */
+    std::pair<bool, HeuristicTree *> get_heuristic_tree(HeuristicTree::TreeID id);
+    /** Add a heuristic tree
+     *
+     * @param[in] t Heuristic tree to be added
+     *
+     * @return bool Signals if the tree was added. Adding fails when the max number of trees is reached or when the
+     *         id or the index of the tree is already registered
+     */
+    bool add_heuristic_tree(HeuristicTree &&t);
+
+    /** Check the validity of the heuristic tree.
+     *
+     * @param[in] id ID of the tree to be checked
+     *
+     * @return bool
+     */
+    bool check_heuristic_tree(HeuristicTree::TreeID id);
+
+    /** Check the overall validity of the heuristics.
+     * @return bool
+     */
+    bool check_all() const;
+
+private:
+    static constexpr size_t _max_num_trees{ 100 }; /**< Max number of trees that can be added*/
+
+private:
+    // There exists a one-to-one mapping between TreeID and Index, either can be used to identify a @ref HeuristicTree
+    std::map<HeuristicTree::TreeID, HeuristicTree::Index> _indices;    /**< A mapping from TreeID to Index */
+    std::map<HeuristicTree::Index, HeuristicTree>         _trees;      /**< A mapping from Index to HeuristicTree */
+    std::map<HeuristicTree::TreeID, bool>                 _tree_valid; /**< Result cache of the tree validity checks */
+    bool                                                  _valid;      /**< Overall validity */
+};
+
+} // namespace mlgo
+} // namespace arm_compute
+#endif //SRC_RUNTIME_MLGO_MLGOHEURISTICS_H \ No newline at end of file
diff --git a/src/runtime/CL/mlgo/MLGOParser.cpp b/src/runtime/CL/mlgo/MLGOParser.cpp
new file mode 100644
index 0000000000..625739e450
--- /dev/null
+++ b/src/runtime/CL/mlgo/MLGOParser.cpp
@@ -0,0 +1,812 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/CL/mlgo/MLGOParser.h"
+#include "arm_compute/core/Log.h"
+#include "src/runtime/CL/mlgo/Utils.h"
+
+#include <sstream>
+
+// Helper macros of the parser. Every macro below expands to SEVERAL statements (an expression statement followed by
+// an if/return, or an assignment followed by a return). They may therefore follow "x = " or stand on their own, but
+// must not be used as the un-braced body of an if/else/loop.
+
+// Evaluate parser_expr, then return from the enclosing void function if valid_var is false.
+// "auto x = CHECK(expr, valid);" assigns the result of expr to x before the validity test.
+#define CHECK(parser_expr, valid_var) \
+ (parser_expr); \
+ if(!valid_var) \
+ return;
+
+// Same as CHECK, for functions with a return value: returns default_val if valid_var is false.
+#define CHECK_DEFAULT(parser_expr, valid_var, default_val) \
+ (parser_expr); \
+ if(!valid_var) \
+ return default_val;
+
+#ifdef ARM_COMPUTE_LOGGING_ENABLED
+
+// Log msg together with the position pos, set valid_var to false and return from the enclosing void function.
+// Declares a local named ss, so it can be used at most once per scope.
+#define FAIL_WITH_MSG(valid_var, pos, msg) \
+ std::stringstream ss; \
+ ss << "MLGOParser Error: " << pos << " " << msg; \
+ ARM_COMPUTE_LOG_INFO_MSG_CORE(ss.str().c_str()); \
+ valid_var = false; \
+ return;
+
+// Same as FAIL_WITH_MSG, for functions with a return value: returns default_val.
+#define FAIL_WITH_MSG_DEFAULT(valid_var, default_val, pos, msg) \
+ std::stringstream ss; \
+ ss << "MLGOParser Error: " << pos << " " << msg; \
+ ARM_COMPUTE_LOG_INFO_MSG_CORE(ss.str().c_str()); \
+ valid_var = false; \
+ return default_val;
+
+// Capture the position of the current token of the stream in a const local named pos_var.
+#define LOG_TOKEN_POS(tokens, pos_var) \
+ const auto pos_var = tokens.current_pos();
+
+#else // ARM_COMPUTE_LOGGING_ENABLED
+
+// Logging disabled: pos and msg are discarded; only the validity flag is cleared before returning.
+#define FAIL_WITH_MSG(valid_var, pos, msg) \
+ valid_var = false; \
+ return;
+
+#define FAIL_WITH_MSG_DEFAULT(valid_var, default_val, pos, msg) \
+ valid_var = false; \
+ return default_val;
+
+// Logging disabled: expands to nothing, so no position variable exists.
+#define LOG_TOKEN_POS(tokens, pos_var)
+
+#endif // ARM_COMPUTE_LOGGING_ENABLED
+namespace
+{
+/** Remove the leading white space of a string in place
+ *
+ * @param[in,out] str String to trim
+ */
+void ltrim(std::string &str)
+{
+    // std::isspace has undefined behaviour for negative char values, hence the unsigned char parameter
+    str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](unsigned char ch)
+    {
+        return std::isspace(ch) == 0;
+    }));
+}
+
+/** Remove the trailing white space of a string in place
+ *
+ * @param[in,out] str String to trim
+ */
+void rtrim(std::string &str)
+{
+    // Search backwards for the last non-space character; base() points one past it
+    str.erase(std::find_if(str.rbegin(), str.rend(), [](unsigned char ch)
+    {
+        return std::isspace(ch) == 0;
+    }).base(),
+    str.end());
+}
+
+/** Remove both the leading and the trailing white space of a string in place
+ *
+ * @param[in,out] str String to trim
+ */
+void trim(std::string &str)
+{
+    ltrim(str);
+    rtrim(str);
+}
+} // namespace
+
+namespace arm_compute
+{
+namespace mlgo
+{
+namespace parser
+{
+/** Type tag of a comparator operand, parsed from the text tokens "enum", "num" and "var" */
+enum class ComparatorType
+{
+ Enum,
+ Num,
+ Var
+};
+
+// Bind the token stream to the input stream s; every character of delims acts as a token delimiter.
+// One token (or the End token, if s is already exhausted) is read eagerly so that _tokens is never empty.
+TokenStream::TokenStream(std::istream &s, const std::string &delims)
+ : _delims{ delims }, _istream{ s }, _tokens{}, _lookahead_pos{}
+{
+ read();
+}
+
+// True while there are tokens other than the terminating End token left in the stream.
+TokenStream::operator bool() const
+{
+ ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty");
+ return !reached_end();
+}
+
+// Pop and return the current token. The buffer is refilled straight away when it runs empty, so that the
+// "never empty" invariant of _tokens holds again on return.
+Token TokenStream::take()
+{
+    ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty");
+    Token current = _tokens.front();
+    _tokens.pop_front();
+    const bool drained = _tokens.empty();
+    if(drained)
+    {
+        read();
+    }
+    return current;
+}
+// Return a copy of the i-th next token without consuming anything. Tokens are read on demand; when the input is
+// exhausted before the i-th token exists, the last buffered token is returned instead.
+Token TokenStream::peek(size_t i)
+{
+    ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty");
+    ARM_COMPUTE_ERROR_ON_MSG(i >= max_look_ahead, "TokenStream: Exceeding max look ahead");
+    // NOTE: If i exceeds the stream (_istream.eof()), read() automatically appends a End token at the end
+    while(_istream && i >= _tokens.size())
+    {
+        read();
+    }
+    // Clamp the index to the last buffered token
+    const size_t last_ind = _tokens.size() - 1;
+    const size_t ind      = (i < last_ind) ? i : last_ind;
+    return _tokens[ind];
+}
+
+// Move the position past the character ch: a newline starts a new line at column 0, any other character
+// moves one column to the right.
+void advance(CharPosition &pos, char ch)
+{
+    if(ch != '\n')
+    {
+        pos.col += 1;
+        return;
+    }
+    pos.ln += 1;
+    pos.col = 0;
+}
+// Step the position back by one column. Only the column is touched: the line number is never decremented,
+// and no guard is applied here against col already being 0.
+void rewind(CharPosition &pos)
+{
+ pos.col -= 1;
+}
+// Read the next token from the input stream and append it to _tokens.
+// When the input is exhausted an End token is appended instead, unless the buffer already consists of
+// nothing but the End token.
+void TokenStream::read()
+{
+    char ch{};
+    // Skip any leading space and delim characters
+    do
+    {
+        // Reached eof
+        if(!_istream.get(ch))
+        {
+            if(!reached_end())
+            {
+                _tokens.emplace_back(TokenType::End, "", _lookahead_pos);
+            }
+            return;
+        }
+        advance(_lookahead_pos, ch);
+    }
+    // The cast is required: std::isspace has undefined behaviour for negative char values
+    while(std::isspace(static_cast<unsigned char>(ch)) || is_delim(ch));
+    // Read chars until we hit a delim or eof
+    auto orig_pos = _lookahead_pos;
+    auto tok      = recognize_tok(ch);
+    // _lookahead_pos was already advanced past the first character of the token; step back to its start
+    rewind(orig_pos);
+    tok.pos = orig_pos;
+    // Trim leading and trailing white spaces
+    trim(tok.value);
+    _tokens.push_back(std::move(tok));
+}
+
+// Dispatch on the first character of a token: brackets become list tokens, '.' starts the fractional part of
+// a float, a digit starts a number and everything else starts a text token.
+Token TokenStream::recognize_tok(char ch)
+{
+    if(ch == '[')
+    {
+        return Token{ TokenType::L_List, "", _lookahead_pos };
+    }
+    if(ch == ']')
+    {
+        return Token{ TokenType::R_List, "", _lookahead_pos };
+    }
+    if(ch == '.')
+    {
+        return float_after_dp_st(std::string{ ch });
+    }
+    // The cast is required: std::isdigit has undefined behaviour for negative char values
+    if(std::isdigit(static_cast<unsigned char>(ch)))
+    {
+        return num_st(std::string{ ch });
+    }
+    return text_st(std::string{ ch });
+}
+
+// Number state: keep consuming digits into value. A '.' switches to the float state; any other character ends
+// the Int token. A terminating delimiter or white space is consumed, everything else is put back into the stream.
+Token TokenStream::num_st(std::string value)
+{
+    char ch{};
+    while(_istream.get(ch))
+    {
+        advance(_lookahead_pos, ch);
+        if(ch == '.')
+        {
+            return float_after_dp_st(value + ch);
+        }
+        // <cctype> classifiers have undefined behaviour for negative char values
+        const auto uch = static_cast<unsigned char>(ch);
+        if(!std::isdigit(uch))
+        {
+            if(!is_delim(ch) && !std::isspace(uch))
+            {
+                rewind(_lookahead_pos);
+                _istream.unget();
+            }
+            break;
+        }
+        value += ch;
+    }
+    return Token{ TokenType::Int, value, _lookahead_pos };
+}
+
+// Float state, entered after the decimal point: keep consuming digits into value. A terminating delimiter or
+// white space is consumed, everything else is put back into the stream.
+Token TokenStream::float_after_dp_st(std::string value)
+{
+    char ch{};
+    while(_istream.get(ch))
+    {
+        advance(_lookahead_pos, ch);
+        // <cctype> classifiers have undefined behaviour for negative char values
+        const auto uch = static_cast<unsigned char>(ch);
+        if(!std::isdigit(uch))
+        {
+            if(!is_delim(ch) && !std::isspace(uch))
+            {
+                rewind(_lookahead_pos);
+                _istream.unget();
+            }
+            break;
+        }
+        value += ch;
+    }
+    return Token{ TokenType::Float, value, _lookahead_pos };
+}
+
+// Text state: accumulate characters (white space included) until a delimiter or a list bracket is met.
+// The delimiter is consumed whereas a bracket is put back so that it forms the next token.
+Token TokenStream::text_st(std::string value)
+{
+    char ch{};
+    while(_istream.get(ch))
+    {
+        advance(_lookahead_pos, ch);
+        if(is_delim(ch))
+        {
+            break;
+        }
+        const bool is_bracket = (ch == '[') || (ch == ']');
+        if(is_bracket)
+        {
+            rewind(_lookahead_pos);
+            _istream.unget();
+            break;
+        }
+        value += ch;
+    }
+    return Token{ TokenType::Text, value, _lookahead_pos };
+}
+
+// True when the End token is the only token left in the buffer.
+bool TokenStream::reached_end() const
+{
+ return _tokens.size() == 1 && _tokens.front().type == TokenType::End;
+}
+
+// True when ch is one of the delimiter characters this stream was constructed with.
+bool TokenStream::is_delim(char ch) const
+{
+ return _delims.find(ch) != std::string::npos;
+}
+
+// Consume one token and fail the parse unless it is the End token.
+void end(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    const auto last_tok = in.take();
+    if(last_tok.type == TokenType::End)
+    {
+        return;
+    }
+    FAIL_WITH_MSG(valid, pos, "Unexpected token at the end of stream");
+}
+
+// Consume one Int token and convert its text to bool via stream extraction.
+// Any other token type fails the parse and yields false.
+bool bool_val(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    const auto tok = in.take();
+    if(tok.type != TokenType::Int)
+    {
+        FAIL_WITH_MSG_DEFAULT(valid, false, pos, "Expect bool or int token");
+    }
+    bool              parsed{};
+    std::stringstream converter(tok.value);
+    converter >> parsed;
+    return parsed;
+}
+
+// Consume one Int token and convert its text to int.
+// The parse fails (and -1 is returned) when the token is not an Int token or when its digits do not fit into an
+// int; such an out-of-range literal was previously accepted silently.
+int int_val(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    auto tok = in.take();
+    if(tok.type != TokenType::Int)
+    {
+        FAIL_WITH_MSG_DEFAULT(valid, -1, pos, "Expect int token");
+    }
+    int               val{};
+    std::stringstream converter(tok.value);
+    // An Int token consists of digits only, so a failed extraction means the literal does not fit into an int
+    if(!(converter >> val))
+    {
+        FAIL_WITH_MSG_DEFAULT(valid, -1, pos, "Int token out of range");
+    }
+    return val;
+}
+
+// Parse an int via int_val and convert it to unsigned int. A negative value fails the parse and yields 0.
+unsigned int uint_val(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    const int signed_val = CHECK_DEFAULT(int_val(in, valid), valid, 0);
+    if(signed_val >= 0)
+    {
+        return static_cast<unsigned int>(signed_val);
+    }
+    FAIL_WITH_MSG_DEFAULT(valid, 0, pos, "Expect unsigned int token");
+}
+
+// Consume one Float token and convert its text to float via stream extraction.
+// Any other token type fails the parse and yields 0.f.
+float float_val(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    const auto tok = in.take();
+    if(tok.type != TokenType::Float)
+    {
+        FAIL_WITH_MSG_DEFAULT(valid, 0.f, pos, "Expect float token");
+    }
+    float             parsed{};
+    std::stringstream converter(tok.value);
+    converter >> parsed;
+    return parsed;
+}
+
+// Consume one token and return its text. The parse fails (yielding "") unless it is a non-empty Text token.
+std::string text_val(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    const auto tok          = in.take();
+    const bool is_good_text = (tok.type == TokenType::Text) && !tok.value.empty();
+    if(!is_good_text)
+    {
+        FAIL_WITH_MSG_DEFAULT(valid, "", pos, "Expect a non-empty text token");
+    }
+    return tok.value;
+}
+
+// Test whether the current token is a Text token equal to c_str. On a match the token is consumed when take is
+// true; on a mismatch the stream is left untouched.
+bool accept_text(TokenStream &in, const std::string &c_str, bool take = true)
+{
+    const auto tok = in.peek();
+    if(tok.type != TokenType::Text || tok.value != c_str)
+    {
+        return false;
+    }
+    if(take)
+    {
+        in.take();
+    }
+    return true;
+}
+
+// Consume the text token str; fail the parse when the current token is anything else.
+void expect_text(TokenStream &in, const std::string &str, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_text(in, str))
+    {
+        return;
+    }
+    FAIL_WITH_MSG(valid, pos, std::string("Expect text token: ") + str);
+}
+
+// Consume the current token if it is a list opener '['; report whether it was.
+bool accept_l_list(TokenStream &in)
+{
+    const bool is_opener = in.peek().type == TokenType::L_List;
+    if(is_opener)
+    {
+        in.take();
+    }
+    return is_opener;
+}
+
+// Consume a list opener '['; fail the parse when the current token is anything else.
+void expect_l_list(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_l_list(in))
+    {
+        return;
+    }
+    FAIL_WITH_MSG(valid, pos, "Expect '['");
+}
+
+// Consume the current token if it is a list closer ']'; report whether it was.
+bool accept_r_list(TokenStream &in)
+{
+    const bool is_closer = in.peek().type == TokenType::R_List;
+    if(is_closer)
+    {
+        in.take();
+    }
+    return is_closer;
+}
+
+// Consume a list closer ']'; fail the parse when the current token is anything else.
+void expect_r_list(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_r_list(in))
+    {
+        return;
+    }
+    FAIL_WITH_MSG(valid, pos, "Expect ']'");
+}
+
+// Parse one of the conditional operators "<=", ">=", "==", "<" and ">".
+// Any other token fails the parse and yields ConditionalOp::EQ.
+ConditionalOp conditional_op(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_text(in, "<="))
+    {
+        return ConditionalOp::LE;
+    }
+    if(accept_text(in, ">="))
+    {
+        return ConditionalOp::GE;
+    }
+    if(accept_text(in, "=="))
+    {
+        return ConditionalOp::EQ;
+    }
+    if(accept_text(in, "<"))
+    {
+        return ConditionalOp::LT;
+    }
+    if(accept_text(in, ">"))
+    {
+        return ConditionalOp::GT;
+    }
+    FAIL_WITH_MSG_DEFAULT(valid, ConditionalOp::EQ, pos, "Expect conditional op");
+}
+
+// Parse the gemm-version field: the text "gemm-version" followed by a list of three unsigned ints.
+// The three numbers are only checked syntactically; their values are discarded.
+void gemm_version(TokenStream &in, bool &valid)
+{
+ CHECK(expect_text(in, "gemm-version", valid), valid);
+ CHECK(expect_l_list(in, valid), valid);
+ CHECK(uint_val(in, valid), valid);
+ CHECK(uint_val(in, valid), valid);
+ CHECK(uint_val(in, valid), valid);
+ CHECK(expect_r_list(in, valid), valid);
+}
+
+// Parse the ip-type field: the text "ip-type" followed by either "gpu" or "cpu".
+// The value is only checked, not stored.
+void ip_type(TokenStream &in, bool &valid)
+{
+    CHECK(expect_text(in, "ip-type", valid), valid);
+    LOG_TOKEN_POS(in, pos);
+    // Short-circuit evaluation keeps the original order: "cpu" is only tried when "gpu" did not match
+    const bool recognised = accept_text(in, "gpu") || accept_text(in, "cpu");
+    if(!recognised)
+    {
+        FAIL_WITH_MSG(valid, pos, "Expect ip type");
+    }
+}
+
+// Parse the header section: "<header>", gemm-version, ip-type, "</header>". Nothing of it is stored.
+void header(TokenStream &in, bool &valid)
+{
+ CHECK(expect_text(in, "<header>", valid), valid);
+ CHECK(gemm_version(in, valid), valid);
+ CHECK(ip_type(in, valid), valid);
+ CHECK(expect_text(in, "</header>", valid), valid);
+}
+
+// Parse a data type: "f16", "f32" or "qasymm8".
+// Any other token fails the parse and yields DataType::QASYMM8.
+DataType data_type(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_text(in, "f16"))
+    {
+        return DataType::F16;
+    }
+    if(accept_text(in, "f32"))
+    {
+        return DataType::F32;
+    }
+    if(accept_text(in, "qasymm8"))
+    {
+        return DataType::QASYMM8;
+    }
+    FAIL_WITH_MSG_DEFAULT(valid, DataType::QASYMM8, pos, "Expect data type");
+}
+
+// Parse a comparator type: "var", "num" or "enum".
+// Any other token fails the parse and yields ComparatorType::Num.
+ComparatorType comparator_type(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_text(in, "var"))
+    {
+        return ComparatorType::Var;
+    }
+    if(accept_text(in, "num"))
+    {
+        return ComparatorType::Num;
+    }
+    if(accept_text(in, "enum"))
+    {
+        return ComparatorType::Enum;
+    }
+    FAIL_WITH_MSG_DEFAULT(valid, ComparatorType::Num, pos, "Expect comparator type");
+}
+
+// Parse a heuristic type: "gemm-type", "gemm-config-native", "gemm-config-reshaped-only-rhs" or
+// "gemm-config-reshaped". The matching token is only consumed when take is true.
+// Any other token fails the parse and yields HeuristicType::GEMM_Config_Reshaped.
+HeuristicType heuristic_type(TokenStream &in, bool &valid, bool take = true)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_text(in, "gemm-type", take))
+    {
+        return HeuristicType::GEMM_Type;
+    }
+    if(accept_text(in, "gemm-config-native", take))
+    {
+        return HeuristicType::GEMM_Config_Native;
+    }
+    if(accept_text(in, "gemm-config-reshaped-only-rhs", take))
+    {
+        return HeuristicType::GEMM_Config_Reshaped_Only_RHS;
+    }
+    if(accept_text(in, "gemm-config-reshaped", take))
+    {
+        return HeuristicType::GEMM_Config_Reshaped;
+    }
+    FAIL_WITH_MSG_DEFAULT(valid, HeuristicType::GEMM_Config_Reshaped, pos, "Expect heuristic type");
+}
+
+// Check that the current token names the heuristic type expected_ht and consume it.
+// The type is first parsed without consuming the token, so that nothing is taken on a mismatch.
+void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &valid)
+{
+ LOG_TOKEN_POS(in, pos);
+ auto ht = CHECK(heuristic_type(in, valid, false), valid);
+ if(ht != expected_ht)
+ {
+ FAIL_WITH_MSG(valid, pos, "Unexpected heuristic type");
+ }
+ // Types agree: parse the same token again, this time consuming it
+ CHECK(heuristic_type(in, valid, true), valid);
+}
+
+// Parse a gemm type: "native", "reshaped-only-rhs" or "reshaped".
+// Any other token fails the parse and yields GEMMType::RESHAPED_ONLY_RHS.
+GEMMType gemm_type(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_text(in, "native"))
+    {
+        return GEMMType::NATIVE;
+    }
+    if(accept_text(in, "reshaped-only-rhs"))
+    {
+        return GEMMType::RESHAPED_ONLY_RHS;
+    }
+    if(accept_text(in, "reshaped"))
+    {
+        return GEMMType::RESHAPED;
+    }
+    FAIL_WITH_MSG_DEFAULT(valid, GEMMType::RESHAPED_ONLY_RHS, pos, "Expect gemm type");
+}
+
+// Parse a native gemm configuration: a list of three unsigned ints (m0, n0, k0).
+// A default-constructed GEMMConfigNative is returned as soon as any element fails to parse.
+GEMMConfigNative gemm_config_native(TokenStream &in, bool &valid)
+{
+ const auto invalid_val = GEMMConfigNative{};
+ CHECK_DEFAULT(expect_l_list(in, valid), valid, invalid_val);
+ const auto m0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
+ return GEMMConfigNative{ m0, n0, k0 };
+}
+
+// Parse a reshaped-only-rhs gemm configuration: a list of four unsigned ints followed by three bools.
+// A default-constructed GEMMConfigReshapedOnlyRHS is returned as soon as any element fails to parse.
+GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &valid)
+{
+ const auto invalid_val = GEMMConfigReshapedOnlyRHS{};
+ CHECK_DEFAULT(expect_l_list(in, valid), valid, invalid_val);
+ const auto m0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto h0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto ir = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
+ const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
+ const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
+ CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
+ return GEMMConfigReshapedOnlyRHS{ m0, n0, k0, h0, ir, tr, ex };
+}
+
+// Parse a reshaped gemm configuration: a list of five unsigned ints followed by four bools.
+// A default-constructed GEMMConfigReshaped is returned as soon as any element fails to parse.
+GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid)
+{
+ const auto invalid_val = GEMMConfigReshaped{};
+ CHECK_DEFAULT(expect_l_list(in, valid), valid, invalid_val);
+ const auto m0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto v0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto h0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
+ const auto il = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
+ const auto ir = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
+ const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
+ const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
+ CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
+ return GEMMConfigReshaped{ m0, n0, k0, v0, h0, il, ir, tr, ex };
+}
+
+// Parse a gpu priority: "best-performance" or "best-memory-usage". The value is only checked, not stored.
+void gpu_priority(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    // Short-circuit evaluation keeps the original order of the alternatives
+    const bool recognised = accept_text(in, "best-performance") || accept_text(in, "best-memory-usage");
+    if(!recognised)
+    {
+        FAIL_WITH_MSG(valid, pos, "Expect gpu priority");
+    }
+}
+
+// Parse a gpu behavior: "static" or "dynamic". The value is only checked, not stored.
+void gpu_behavior(TokenStream &in, bool &valid)
+{
+    LOG_TOKEN_POS(in, pos);
+    if(accept_text(in, "static"))
+    {
+        ;
+    }
+    else if(accept_text(in, "dynamic"))
+    {
+        ;
+    }
+    else
+    {
+        // The message used to read "Expect ip type", copied over from ip_type()
+        FAIL_WITH_MSG(valid, pos, "Expect gpu behavior");
+    }
+}
+
+// Parse the free variables: a list of text tokens. The names are only checked syntactically and discarded.
+void free_vars(TokenStream &in, bool &valid)
+{
+ CHECK(expect_l_list(in, valid), valid);
+ // Any token other than non-empty text (including End when ']' is missing) fails the parse and ends the loop
+ while(!accept_r_list(in))
+ {
+ CHECK(text_val(in, valid), valid);
+ }
+}
+
+// Parse one entry of the heuristics table and register an empty HeuristicTree for it in h.
+// Only the id, the ip name, the data type and the heuristic type are kept; the number of cores, the gpu priority,
+// the gpu behavior and the free variables are parsed and discarded.
+void heuristics_table_entry(TokenStream &in, MLGOHeuristics &h, bool &valid)
+{
+ const auto id = CHECK(uint_val(in, valid), valid);
+ const auto ip = CHECK(text_val(in, valid), valid);
+ CHECK(uint_val(in, valid), valid); // Num cores
+ const auto dt = CHECK(data_type(in, valid), valid);
+ CHECK(gpu_priority(in, valid), valid);
+ CHECK(gpu_behavior(in, valid), valid);
+ const auto ht = CHECK(heuristic_type(in, valid), valid);
+ CHECK(free_vars(in, valid), valid);
+ HeuristicTree t(id, ht, ip, dt);
+ // A tree rejected by h invalidates the parse
+ valid = CHECK(h.add_heuristic_tree(std::move(t)), valid);
+}
+
+// Parse the heuristics table: "<heuristics-table>", any number of entries, "</heuristics-table>".
+void heuristics_table(TokenStream &in, MLGOHeuristics &h, bool &valid)
+{
+ CHECK(expect_text(in, "<heuristics-table>", valid), valid);
+ // A missing closing tag ends the loop through a failing entry parse
+ while(!accept_text(in, "</heuristics-table>"))
+ {
+ CHECK(heuristics_table_entry(in, h, valid), valid);
+ }
+}
+
+// Parse the condition of a branch node: lhs type, lhs value, conditional op, rhs type, rhs value.
+// A default-constructed Condition is returned when the parse fails.
+Condition condition(TokenStream &in, bool &valid)
+{
+ LOG_TOKEN_POS(in, pos);
+ // NOTE: Only simplified Conditions are accepted, which means the lhs comparator type is fixed to Var and that of
+ // the rhs is fixed to Num (float)
+ const auto invalid_val = Condition{};
+ const auto l_t = CHECK_DEFAULT(comparator_type(in, valid), valid, invalid_val);
+ const auto l_v = CHECK_DEFAULT(text_val(in, valid), valid, invalid_val);
+ const auto c_o = CHECK_DEFAULT(conditional_op(in, valid), valid, invalid_val);
+ const auto r_t = CHECK_DEFAULT(comparator_type(in, valid), valid, invalid_val);
+ const auto r_v = CHECK_DEFAULT(float_val(in, valid), valid, invalid_val);
+ // The comparator types are validated only after all five elements were parsed
+ if(l_t != ComparatorType::Var || r_t != ComparatorType::Num)
+ {
+ FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos, "Only accept LHS type to be Var (string) and RHS type to be Num (float)");
+ }
+ return Condition{ l_v, c_o, r_v };
+}
+
+// Parse one heuristic tree: "<heuristic", tree id, ">", any number of branch ("b") and leaf ("l") nodes,
+// "</heuristic>". The nodes are added to the tree that was registered in h under the same id, and the finished
+// tree is checked by h. Parsing stops at the first error with valid set to false.
+void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
+{
+ CHECK(expect_text(in, "<heuristic", valid), valid);
+ const auto tree_id = CHECK(uint_val(in, valid), valid);
+ CHECK(expect_text(in, ">", valid), valid);
+ HeuristicTree *t = nullptr;
+ // Returns early when the tree is unknown, so t is not dereferenced in that case
+ std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid);
+ const HeuristicType t_heuristic_type = std::get<0>(t->index());
+ while(!accept_text(in, "</heuristic>"))
+ {
+ LOG_TOKEN_POS(in, pos);
+ if(accept_text(in, "b"))
+ {
+ // Branch node
+ const auto id = CHECK(uint_val(in, valid), valid);
+ const auto cond = CHECK(condition(in, valid), valid);
+ const auto t_id = CHECK(uint_val(in, valid), valid);
+ const auto f_id = CHECK(uint_val(in, valid), valid);
+ valid = CHECK(t->add_branch(id, cond, t_id, f_id), valid);
+ }
+ else if(accept_text(in, "l"))
+ {
+ // Leaf node
+ const auto id = CHECK(uint_val(in, valid), valid);
+ // NOTE: Heuristic type within each tree appears to be redundant (same information can be obtained from the
+ // heuristic table). For now it remains as a step for validation.
+ LOG_TOKEN_POS(in, pos);
+ CHECK(expect_heuristic_type(in, t_heuristic_type, valid), valid);
+ // The kind of leaf value to parse is dictated by the heuristic type of the tree
+ switch(t_heuristic_type)
+ {
+ case HeuristicType::GEMM_Type:
+ {
+ const auto g_type = CHECK(gemm_type(in, valid), valid);
+ valid = CHECK(t->add_leaf(id, g_type), valid);
+ break;
+ }
+ case HeuristicType::GEMM_Config_Native:
+ {
+ const auto g_c = CHECK(gemm_config_native(in, valid), valid);
+ valid = CHECK(t->add_leaf(id, g_c), valid);
+ break;
+ }
+ case HeuristicType::GEMM_Config_Reshaped_Only_RHS:
+ {
+ const auto g_c = CHECK(gemm_config_reshaped_only_rhs(in, valid), valid);
+ valid = CHECK(t->add_leaf(id, g_c), valid);
+ break;
+ }
+ case HeuristicType::GEMM_Config_Reshaped:
+ {
+ const auto g_c = CHECK(gemm_config_reshaped(in, valid), valid);
+ valid = CHECK(t->add_leaf(id, g_c), valid);
+ break;
+ }
+ default:
+ {
+ FAIL_WITH_MSG(valid, pos, "Unexpected heuristic type");
+ }
+ }
+ }
+ else
+ {
+ FAIL_WITH_MSG(valid, pos, "Expect tree node type");
+ }
+ }
+ // Perform semantic checks in the middle of parsing so that it can fail fast should there be any invalidities
+ valid = CHECK(h.check_heuristic_tree(tree_id), valid);
+}
+
+// Parse a complete DotMLGO document: header, heuristics table, any number of heuristic trees, end of stream.
+// The heuristics built so far are returned as soon as valid turns false, so callers must test valid before
+// using the result.
+MLGOHeuristics mlgo(TokenStream &in, bool &valid)
+{
+ MLGOHeuristics h;
+ CHECK_DEFAULT(header(in, valid), valid, h);
+ CHECK_DEFAULT(heuristics_table(in, h, valid), valid, h);
+ // Only peek for the opening tag; heuristic_tree() consumes it itself
+ while(accept_text(in, "<heuristic", false))
+ {
+ CHECK_DEFAULT(heuristic_tree(in, h, valid), valid, h);
+ }
+ CHECK_DEFAULT(end(in, valid), valid, h);
+ valid = CHECK_DEFAULT(h.check_all(), valid, h);
+ return h;
+}
+
+// Entry point of the parser: tokenize the input stream and parse it as a DotMLGO document.
+// The bool of the returned pair tells whether the parse succeeded.
+std::pair<bool, MLGOHeuristics> parse_mlgo(std::istream &in)
+{
+    TokenStream    tokens(in);
+    bool           parse_ok   = true;
+    MLGOHeuristics heuristics = mlgo(tokens, parse_ok);
+    return std::pair<bool, MLGOHeuristics>(parse_ok, std::move(heuristics));
+}
+} // namespace parser
+} // namespace mlgo
+} // namespace arm_compute
+
+// Undefine every helper macro defined at the top of this file, so that none leaks into other sources
+// (e.g. in unity builds)
+#undef CHECK
+#undef CHECK_DEFAULT
+#undef FAIL_WITH_MSG
+#undef FAIL_WITH_MSG_DEFAULT
+#undef LOG_TOKEN_POS \ No newline at end of file
diff --git a/src/runtime/CL/mlgo/MLGOParser.h b/src/runtime/CL/mlgo/MLGOParser.h
new file mode 100644
index 0000000000..e4a31c1f55
--- /dev/null
+++ b/src/runtime/CL/mlgo/MLGOParser.h
@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_MLGO_MLGOPARSER_H
+#define SRC_MLGO_MLGOPARSER_H
+
+#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
+
+#include <deque>
+#include <istream>
+#include <string>
+#include <utility>
+
+/** A DotMLGO file parser (LL(k) parser)
+ *
+ * The grammar of DotMLGO is defined by the following EBNF:
+ *
+ * delim = "," | "\n"; // Note that delimiters are omitted from the definition below
+ *
+ * mlgo = header, heuristics-table, {heuristic-tree};
+ *
+ * header = "<header>", gemm-version, ip-type, "</header>";
+ * gemm-version = "gemm-version", "[", int, int, int, "]";
+ * ip-type = "ip-type", ("gpu" | "cpu");
+ *
+ * heuristics-table = "<heuristics-table>", {heuristics-table-entry}, "</heuristics-table>";
+ * heuristics-table-entry = entry-id, ip-name, num-cores, data-type, gpu-priority, gpu-behavior, heuristic-type, free-vars;
+ * entry-id = int;
+ * ip-name = char-sequence;
+ * num-cores = int;
+ * data-type = "f32" | "f16" | "qasymm8";
+ * gpu-priority = "best-performance" | "best-memory-usage";
+ * gpu-behavior = "static" | "dynamic";
+ * heuristic-type = "gemm-type" | "gemm-config-native" | "gemm-config-reshaped-only-rhs" | "gemm-config-reshaped";
+ * free-vars = "[", {char-sequence}, "]";
+ *
+ * heuristic-tree = "<heuristic", entry-id, ">", {tree-node}, "</heuristic>";
+ * tree-node = branch-node | leaf-node;
+ * branch-node = "b", entry-id, lhs-type, lhs-value, conditional-op, rhs-type, rhs-value, true-node, false-node;
+ * lhs-type = comparator-type;
+ * lhs-value = comparator-value;
+ * rhs-type = comparator-type;
+ * rhs-value = comparator-value;
+ * comparator-type = "var" | "num" | "enum";
+ * comparator-value = char-sequence | float;
+ * conditional-op = "<" | "<=" | "==" | ">=" | ">";
+ * true-node = entry-id;
+ * false-node = entry-id;
+ * leaf-node = "l", entry-id, heuristic-type, leaf-value;
+ * leaf-value = gemm-type | gemm-config-native | gemm-config-reshaped-only-rhs | gemm-config-reshaped;
+ * gemm-type = "native" | "reshaped-only-rhs" | "reshaped";
+ * gemm-config-native = "[", int, int, int, "]";
+ * gemm-config-reshaped-only-rhs = "[", int, int, int, int, bool, bool, bool, "]";
+ * gemm-config-reshaped = "[", int, int, int, int, int, bool, bool, bool, bool, "]";
+ */
+
+namespace arm_compute
+{
+namespace mlgo
+{
+namespace parser
+{
+/** Type of Token
+ *
+ * The two list types carry the character code of their bracket as value; the remaining enumerators have no
+ * explicit value and are therefore numbered consecutively after ']'.
+ */
+enum class TokenType
+{
+ L_List = '[', /**< List open */
+ R_List = ']', /**< List close */
+ Int, /**< Integral */
+ Float, /**< Floating */
+ Text, /**< Text/String */
+ End, /**< End of stream */
+};
+
+/** Position of a character in the input, as a line / column pair; both start at 0 */
+struct CharPosition
+{
+ /** Two positions are equal when both line and column agree */
+ bool operator==(const CharPosition &other) const
+ {
+ return ln == other.ln && col == other.col;
+ }
+
+ size_t ln{ 0 }; /**< Line number */
+ size_t col{ 0 }; /**< Column number */
+};
+
+/** Token */
+struct Token
+{
+    /** Constructor
+     *
+     * @param[in] t   Token type
+     * @param[in] v   Token value (text of the token)
+     * @param[in] pos Position of the token in the input
+     */
+    Token(TokenType t, std::string v, CharPosition pos)
+        : type{ t }, value{ std::move(v) }, pos{ pos } // v is a by-value sink parameter: move it instead of copying
+    {
+    }
+
+    /** Two tokens are equal when type, value and position all agree */
+    bool operator==(const Token &other) const
+    {
+        return type == other.type && value == other.value && pos == other.pos;
+    }
+
+    TokenType    type;  /**< Token type */
+    std::string  value; /**< Token value */
+    CharPosition pos;   /**< Token position */
+};
+
+/** A stream of tokens, read on demand from an input stream */
+class TokenStream
+{
+ // NOTE: _tokens is never empty. The end of token stream is signalled by the End Token
+public:
+ static constexpr size_t max_look_ahead = 10; /**< Exclusive upper bound of the index accepted by @ref peek */
+
+public:
+ /** Constructor
+ *
+ * @param[in] s Input stream
+ * @param[in] delims Delimiter characters packed in a string. Each char from the string can be used as a delim on its own
+ */
+ TokenStream(std::istream &s, const std::string &delims = ",\n");
+
+ /** Check if there're more (non-End) Tokens
+ * @return true If there are more tokens
+ * @return false If reached end of stream (only End token)
+ */
+ explicit operator bool() const;
+
+ /** Get and pop off the current token
+ *
+ * @return Token
+ */
+ Token take();
+
+ /** Peek the next ith token
+ *
+ * @param[in] i The next ith token. i < @ref max_look_ahead.
+ *
+ * @return Token
+ */
+ Token peek(size_t i = 0);
+
+ /** Get the position of the current token
+ *
+ * @return CharPosition
+ */
+ CharPosition current_pos() const
+ {
+ return _tokens.front().pos;
+ }
+
+private:
+ /** Read the next token (or the End token) from the input stream and append it to the buffer */
+ void read();
+
+ /** Build the token that starts with the character ch */
+ Token recognize_tok(char ch);
+
+ /** Tokenizer state for a number; value holds the characters read so far */
+ Token num_st(std::string value = "");
+
+ /** Tokenizer state for a float after its decimal point; value holds the characters read so far */
+ Token float_after_dp_st(std::string value = "");
+
+ /** Tokenizer state for text; value holds the characters read so far */
+ Token text_st(std::string value = "");
+
+ /** Check if the End token is the only token left */
+ bool reached_end() const;
+
+ /** Check if ch is one of the delimiter characters */
+ bool is_delim(char ch) const;
+
+ std::string _delims; /**< Delimiter characters */
+ std::istream &_istream; /**< Input stream the tokens are read from */
+ std::deque<Token> _tokens; /**< Buffer of the tokens read but not yet taken */
+ CharPosition _lookahead_pos; /**< Position reached in the input stream */
+};
+
+/** Parse and construct a @ref MLGOHeuristics from input stream
+ *
+ * @param[in] in Input stream
+ *
+ * @return std::pair<bool, MLGOHeuristics> bool signals if the parsing succeeded or failed. The heuristics must only
+ *         be used when the bool is true
+ */
+std::pair<bool, MLGOHeuristics> parse_mlgo(std::istream &in);
+
+} // namespace parser
+} // namespace mlgo
+} // namespace arm_compute
+#endif //SRC_MLGO_MLGOPARSER_H \ No newline at end of file
diff --git a/src/runtime/CL/mlgo/Utils.cpp b/src/runtime/CL/mlgo/Utils.cpp
new file mode 100644
index 0000000000..bd06bdf521
--- /dev/null
+++ b/src/runtime/CL/mlgo/Utils.cpp
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/CL/mlgo/Utils.h"
+
+namespace arm_compute
+{
+namespace mlgo
+{
+std::ostream &operator<<(std::ostream &os, const GEMMConfigNative &config) // Writes "Native:{m0: <m0>, n0: <n0>, k0: <k0>}" to os and returns os
+{
+ return os << "Native:{"
+ << "m0: " << config.m0 << ", "
+ << "n0: " << config.n0 << ", "
+ << "k0: " << config.k0 // Last field: no trailing separator, matching the ReshapedOnlyRHS and Reshaped printers
+ << "}";
+}
+std::ostream &operator<<(std::ostream &os, const GEMMConfigReshapedOnlyRHS &config) // Writes the config as "ReshapedOnlyRHS:{<name>: <value>, ...}" to os and returns os
+{
+ return os << "ReshapedOnlyRHS:{"
+ << "m0: " << config.m0 << ", "
+ << "n0: " << config.n0 << ", "
+ << "k0: " << config.k0 << ", "
+ << "h0: " << config.h0 << ", "
+ << "interleave_rhs: " << config.interleave_rhs << ", "
+ << "transpose_rhs: " << config.transpose_rhs << ", "
+ << "export_cl_image: " << config.export_cl_image // Last field: no trailing separator
+ << "}";
+}
+std::ostream &operator<<(std::ostream &os, const GEMMConfigReshaped &config) // Writes the config as "Reshaped:{<name>: <value>, ...}" to os and returns os
+{
+ return os << "Reshaped:{"
+ << "m0: " << config.m0 << ", "
+ << "n0: " << config.n0 << ", "
+ << "k0: " << config.k0 << ", "
+ << "v0: " << config.v0 << ", "
+ << "h0: " << config.h0 << ", "
+ << "interleave_lhs: " << config.interleave_lhs << ", "
+ << "interleave_rhs: " << config.interleave_rhs << ", "
+ << "transpose_rhs: " << config.transpose_rhs << ", "
+ << "export_cl_image: " << config.export_cl_image // Last field: no trailing separator
+ << "}";
+}
+std::ostream &operator<<(std::ostream &os, const HeuristicType &ht) // Writes the enumerator name of ht to os and returns os
+{
+ switch(ht)
+ {
+ case HeuristicType::GEMM_Type:
+ {
+ os << "GEMM_Type";
+ break;
+ }
+ case HeuristicType::GEMM_Config_Reshaped_Only_RHS:
+ {
+ os << "GEMM_Config_Reshaped_Only_RHS";
+ break;
+ }
+ case HeuristicType::GEMM_Config_Reshaped:
+ {
+ os << "GEMM_Config_Reshaped";
+ break;
+ }
+ default: // Any value not handled by the cases above is printed as "Unknown"
+ {
+ os << "Unknown";
+ break;
+ }
+ }
+ return os;
+}
+std::ostream &operator<<(std::ostream &os, const DataType &dt) // Writes the name of dt ("F32", "F16" or "QASYMM8") to os and returns os
+{
+ switch(dt)
+ {
+ case DataType::F32:
+ {
+ os << "F32";
+ break;
+ }
+ case DataType::F16:
+ {
+ os << "F16";
+ break;
+ }
+ case DataType::QASYMM8:
+ {
+ os << "QASYMM8";
+ break;
+ }
+ default: // Any value not handled by the cases above is printed as "Unknown"
+ {
+ os << "Unknown";
+ break;
+ }
+ }
+ return os;
+}
+std::ostream &operator<<(std::ostream &os, const HeuristicTree::Index &index) // Writes "Index(HeuristicType=<ht>,IP=<ip>,DataType=<dt>)" to os and returns os
+{
+ HeuristicType ht;
+ std::string ip;
+ DataType dt;
+ std::tie(ht, ip, dt) = index; // Unpack the three components of the index into the locals above
+ os << "Index(";
+ os << "HeuristicType=" << ht << ","; // Uses the HeuristicType printer defined earlier in this file
+ os << "IP=" << ip << ",";
+ os << "DataType=" << dt; // Uses the DataType printer defined earlier in this file
+ os << ")";
+ return os;
+}
+
+namespace parser
+{
+std::ostream &operator<<(std::ostream &os, CharPosition pos) // Writes pos as "(Ln: <ln>, Col: <col>)" to os and returns os
+{
+ os << "(Ln: " << pos.ln << ", Col: " << pos.col << ")";
+ return os;
+}
+} // namespace parser
+
+} // namespace mlgo
+
+} // namespace arm_compute \ No newline at end of file
diff --git a/src/runtime/CL/mlgo/Utils.h b/src/runtime/CL/mlgo/Utils.h
new file mode 100644
index 0000000000..2e324dd439
--- /dev/null
+++ b/src/runtime/CL/mlgo/Utils.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_MLGO_UTILS_H
+#define SRC_MLGO_UTILS_H
+
+#include "src/runtime/CL/mlgo/Common.h"
+#include "src/runtime/CL/mlgo/HeuristicTree.h"
+#include "src/runtime/CL/mlgo/MLGOParser.h"
+
+#include <ostream>
+
+namespace arm_compute
+{
+namespace mlgo
+{
+std::ostream &operator<<(std::ostream &os, const GEMMConfigNative &config); // Stream insertion for GEMMConfigNative; returns os
+std::ostream &operator<<(std::ostream &os, const GEMMConfigReshapedOnlyRHS &config); // Stream insertion for GEMMConfigReshapedOnlyRHS; returns os
+std::ostream &operator<<(std::ostream &os, const GEMMConfigReshaped &config); // Stream insertion for GEMMConfigReshaped; returns os
+std::ostream &operator<<(std::ostream &os, const HeuristicType &ht); // Stream insertion for HeuristicType; returns os
+std::ostream &operator<<(std::ostream &os, const DataType &dt); // Stream insertion for DataType; returns os
+std::ostream &operator<<(std::ostream &os, const HeuristicTree::Index &index); // Stream insertion for HeuristicTree::Index; returns os
+namespace parser
+{
+std::ostream &operator<<(std::ostream &os, CharPosition pos); // Stream insertion for a character position; parameter named pos to match the definition in Utils.cpp
+} // namespace parser
+} // namespace mlgo
+} // namespace arm_compute
+
+#endif //SRC_MLGO_UTILS_H \ No newline at end of file
diff --git a/tests/SConscript b/tests/SConscript
index 92cf47b5c6..041ed8f548 100644
--- a/tests/SConscript
+++ b/tests/SConscript
@@ -284,4 +284,4 @@ if test_env['benchmark_examples']:
Depends(arm_compute_benchmark_examples, arm_compute_test_framework)
Depends(arm_compute_benchmark_examples, arm_compute_lib)
Default(arm_compute_benchmark_examples)
- Export('arm_compute_benchmark_examples')
+ Export('arm_compute_benchmark_examples') \ No newline at end of file
diff --git a/tests/validation/CL/UNIT/MLGOHeuristics.cpp b/tests/validation/CL/UNIT/MLGOHeuristics.cpp
new file mode 100644
index 0000000000..895b4b51d0
--- /dev/null
+++ b/tests/validation/CL/UNIT/MLGOHeuristics.cpp
@@ -0,0 +1,461 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
+#include "src/runtime/CL/mlgo/Utils.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+
+using namespace arm_compute::mlgo;
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+TEST_SUITE(CL)
+TEST_SUITE(UNIT)
+TEST_SUITE(MLGOHeuristics)
+TEST_CASE(CorrectDotMLGOShouldLoadCorrectly, framework::DatasetMode::ALL) // A well-formed dotmlgo with three heuristic trees must load and answer queries with the leaves encoded in the fixture
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
+ 1, g71 , 8, f16, best-performance, static, gemm-config-reshaped-only-rhs, [m,n,k,n]
+ 2, g76 , 8, f16, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b , 0, var, m, ==, num, 10., 1, 2
+ l , 1, gemm-type, reshaped
+ b , 2, var, r_mn, >=, num, 2., 3, 6
+ b , 3, var, n, >=, num, 200., 4, 5
+ l , 4, gemm-type, reshaped-only-rhs
+ l , 5, gemm-type, reshaped
+ l , 6, gemm-type, reshaped-only-rhs
+ </heuristic>
+ <heuristic, 1>
+ b ,0,var, n, >, num, 100., 1, 4
+ b ,1,var, r_mnk, <=, num, 20., 2, 3
+ l ,2,gemm-config-reshaped-only-rhs, [4, 4,4,2,1,0,1]
+ l ,3,gemm-config-reshaped-only-rhs,[ 2, 2,4,2,1,1, 1 ]
+ b ,4,var, n, >=, num, 199.12, 5, 6
+ l ,5,gemm-config-reshaped-only-rhs, [1, 4,3,4,0,0,0]
+ l ,6,gemm-config-reshaped-only-rhs, [5, 4,4,5,1,1,0]
+ </heuristic>
+ <heuristic, 2>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Assert the load itself so a parse failure is reported here rather than as unrelated query mismatches below
+
+ ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 10, 1024, 20, 1 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 201, 5, 1 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 200, 199, 16 }).second == GEMMType::RESHAPED_ONLY_RHS, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(heuristics.query_gemm_type(Query{ "g76", DataType::F32, 400, 199, 512, 4 }).second == GEMMType::RESHAPED, framework::LogLevel::ERRORS);
+
+ ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }),
+ framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 100, 1024, 20, 32 }).second == GEMMConfigReshapedOnlyRHS{ 4, 4, 4, 2, true, false, true }),
+ framework::LogLevel::ERRORS); // NOTE(review): this assertion repeats the previous one verbatim - confirm whether a different query was intended
+ ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 128, 101, 20, 1 }).second == GEMMConfigReshapedOnlyRHS{ 2, 2, 4, 2, true, true, true }),
+ framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }),
+ framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g71", DataType::F16, 400, 100, 512, 1 }).second == GEMMConfigReshapedOnlyRHS{ 5, 4, 4, 5, true, true, false }),
+ framework::LogLevel::ERRORS); // NOTE(review): this assertion repeats the previous one verbatim - confirm whether a different query was intended
+
+ ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 100, 100, 20, 32 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }),
+ framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT((heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F16, 128, 512, 1024, 1 }).second == GEMMConfigReshaped{ 4, 2, 4, 2, 8, true, false, true, false }),
+ framework::LogLevel::ERRORS);
+}
+
+TEST_CASE(InvalidDotmlgoSyntaxShouldReturnInvalidStatus, framework::DatasetMode::ALL) // The fixture is syntactically broken (note the truncated "</heurist" tag); loading it must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,pu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heurist
+ <heuristic, 0>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+
+TEST_SUITE(InvalidDotmlgoSemanticsShouldReturnInvalidStatus)
+// If the semantic errors are local to some trees instead of the entire heuristics, an alternative is to simply
+// ignore/remove those invalid trees. However the reason why we choose to throw, thus invalidating the entire
+// heuristics is that if there are some invalid trees, the quality of the dotmlgo is called into question even if
+// the rest of the trees are semantically valid, and they could severely degrade the performance of GEMM. Therefore
+// this "all or nothing" approach when it comes to dotmlgo correctness is safer and more defensive.
+
+// Also note that the semantic error of the tree only refers to those that obstruct its evaluation and thus query,
+// (e.g. invalid tree structure, unsupported features etc.) instead of those affecting the desired outcome
+// (usually in terms of final GEMM performance, e.g. the effectiveness of the decision tree)
+
+// In the future we might want to check the content of the exceptions as well. But right now it suffices to only
+// know that it throws exactly when it needs to.
+TEST_CASE(MismatchesBetweenHeuristicsTableEntriesAndHeuristicTrees, framework::DatasetMode::ALL) // Three fixtures where the heuristics table and the heuristic trees disagree; each load must report failure
+{
+ {
+ // Mismatching number of entries 1: the table lists entry 0 but no <heuristic> section follows
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ // NOTE: This case might throw an internal error as the tree inserted by the heuristics-table cannot be checked
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
+ }
+
+ {
+ // Mismatching number of entries 2: the table is empty but a <heuristic, 1> section is present
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ </heuristics-table>
+ <heuristic, 1>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
+ }
+
+ {
+ // Mismatching info: the table declares entry 0 as gemm-type but the tree's leaf is a gemm-config-reshaped
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS);
+ }
+}
+
+TEST_CASE(RepeatedHeuristicsTableEntriesId, framework::DatasetMode::ALL) // Two heuristics-table rows share the id 0; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ 0, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ <heuristic, 1>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+
+TEST_CASE(RepeatedHeuristicsTableEntriesIndex, framework::DatasetMode::ALL) // Table rows 0 and 1 differ only in their id (same g76 / f32 / gemm-config-reshaped); loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ 1, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ <heuristic, 1>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+
+TEST_CASE(RepeatedHeuristicTreesId, framework::DatasetMode::ALL) // Two <heuristic> sections both use id 0 (and none uses id 1); loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ 1, g71 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ <heuristic, 0>
+ l ,0,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_CASE(EmptyTree, framework::DatasetMode::ALL) // The <heuristic, 0> section contains no node lines; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+
+TEST_CASE(InvalidTreeMissingRoot, framework::DatasetMode::ALL) // The tree only defines nodes 2, 3 and 4, so there is no node 0 (presumably the required root id - confirm); loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b ,2, var, m, ==, num, 10., 3, 4
+ l ,3,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_CASE(InvalidTreeMissingNodes, framework::DatasetMode::ALL) // Branch node 0 ends with the ids 1 and 2 (presumably its children - confirm) but node 2 is never defined; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b ,0, var, m, ==, num, 10., 1, 2
+ l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_CASE(InvalidTreeRepeatedNodeIds, framework::DatasetMode::ALL) // Node id 1 is defined twice inside the same tree; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b ,0, var, m, ==, num, 10., 1, 2
+ l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ l ,1,gemm-config-reshaped,[1,2,4,2,8,1,0,1,0]
+ l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_CASE(InvalidTreeDisjointNodes, framework::DatasetMode::ALL) // Nodes 4-7 are never referenced from the group formed by nodes 0-2; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b ,0, var, m, ==, num, 10., 1, 2
+ l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ l ,2,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
+
+ b ,4, var, n, ==, num, 10., 5, 6
+ l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ l ,6,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
+
+ l ,7,gemm-config-reshaped,[2,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_CASE(InvalidTreeLoop, framework::DatasetMode::ALL) // Branch node 0 ends with the ids 0 and 1 (presumably its children - confirm), i.e. it refers to itself; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b ,0, var, m, ==, num, 10., 0, 1
+ l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_CASE(InvalidTreeCycle, framework::DatasetMode::ALL) // Node 0 refers to node 1, node 1 to node 3 and node 3 back to node 0; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b ,0, var, m, ==, num, 10., 1, 5
+ b ,1, var, n, ==, num, 10., 2, 3
+ l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ b ,3, var, k, ==, num, 10., 0, 4
+ l ,4,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ l ,5,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_CASE(InvalidTreeInvalidFeatures, framework::DatasetMode::ALL) // Branch node 0 tests a variable named "magic_feature"; loading must report failure
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-config-reshaped, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b ,0, var, magic_feature, ==, num, 10., 1, 2
+ l ,1,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ l ,2,gemm-config-reshaped,[4,2,4,2,8,1,0,1,0]
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(!heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // Expect the load to be rejected
+}
+TEST_SUITE_END() // InvalidDotmlgoSemanticsShouldReturnInvalidStatus
+
+TEST_CASE(InvalidUsageOfHeuristicsShouldReturnInvalidStatus, framework::DatasetMode::ALL) // Loads a valid file holding a single gemm-type tree for g76 / f32, then checks that queries it cannot serve come back with .first == false
+{
+ std::string mlgo_str = R"_(
+ <header>
+ gemm-version, [1,2,1]
+ ip-type,gpu
+ </header>
+ <heuristics-table>
+ 0, g76 , 8, f32, best-performance, static, gemm-type, [m,n,k,n]
+ </heuristics-table>
+ <heuristic, 0>
+ b , 0, var, m, ==, num, 10., 1, 2
+ l , 1, gemm-type, reshaped
+ b , 2, var, r_mn, >=, num, 2., 3, 6
+ b , 3, var, n, >=, num, 200., 4, 5
+ l , 4, gemm-type, reshaped-only-rhs
+ l , 5, gemm-type, reshaped
+ l , 6, gemm-type, reshaped-only-rhs
+ </heuristic>
+ )_";
+ std::stringstream ss(mlgo_str);
+ MLGOHeuristics heuristics;
+ ARM_COMPUTE_EXPECT(heuristics.reload_from_stream(ss), framework::LogLevel::ERRORS); // The fixture itself is valid, so the load is expected to succeed
+
+ // Querying unavailable heuristic type should return invalid Status
+ ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped(Query{ "g76", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
+ // Querying unavailable ip target should return invalid Status
+ ARM_COMPUTE_EXPECT(!heuristics.query_gemm_type(Query{ "g77", DataType::F32, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
+ // Querying unavailable data type should return invalid Status (the queried heuristic type, reshaped-only-rhs config, is not in the table either)
+ ARM_COMPUTE_EXPECT(!heuristics.query_gemm_config_reshaped_only_rhs(Query{ "g76", DataType::QASYMM8, 1024, 1024, 100, 3 }).first, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // MLGOHeuristics
+TEST_SUITE_END() // UNIT
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute