aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/mlgo/HeuristicTree.h
blob: a4f8c116b9d23c5b7a170b59723de41f1af74dac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H
#define SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H

#include "arm_compute/core/Types.h"

#include "src/runtime/CL/mlgo/Common.h"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

namespace arm_compute
{
namespace mlgo
{
/** Comparison operators available to a branch condition (feature op threshold) */
enum class ConditionalOp
{
    EQ = 0, /**< feature == threshold */
    LT = 1, /**< feature <  threshold */
    LE = 2, /**< feature <= threshold */
    GT = 3, /**< feature >  threshold */
    GE = 4, /**< feature >= threshold */
};

/** A branch condition expression evaluating: feature op threshold
 *
 * The feature is the left-hand operand and the threshold the right-hand operand of @ref ConditionalOp.
 * NOTE(review): the set of valid feature names is not defined in this header - confirm against the
 * code that evaluates the condition.
 */
struct Condition
{
    std::string   feature;   /**< Feature name */
    ConditionalOp op;        /**< Conditional op applied between the feature and the threshold */
    float         threshold; /**< Threshold value the feature is compared against */
};

/** GEMM Shape used for query
 *
 * Plain aggregate of the GEMM dimensions; it is the argument of HeuristicTree::query().
 * All matrix dimensions below refer to the NON-transposed matrices.
 */
struct GEMMShape
{
    unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
    unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
    unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
    unsigned int b; /**< Batch size */
};

/** A binary decision tree based heuristic
 *
 * The tree is stored as a map from @ref NodeID to an owned, polymorphic @ref Node. A node is either a
 * @ref BranchNode (a @ref Condition plus the IDs of its two children) or a @ref LeafNode (a value).
 * The nodes are held through std::unique_ptr, so the tree is move-only.
 */
class HeuristicTree
{
public:
    /** Identifier of a node inside a tree */
    using NodeID = std::size_t;
    /** Identifier of a tree */
    using TreeID = std::size_t;
    /** Tree index: (heuristic type, IP target, data type), as returned by @ref index */
    using Index = std::tuple<HeuristicType, std::string, DataType>;

    /** Kind of a tree node */
    enum class NodeType
    {
        Branch, /**< Node holding a condition and two child IDs */
        Leaf    /**< Node holding a value */
    };

    /** Polymorphic base of all tree nodes */
    struct Node
    {
        /** Get the kind of this node
         * @return NodeType
         */
        virtual NodeType type() const = 0;
        /** Virtual destructor: nodes are owned and destroyed through pointers to this base */
        virtual ~Node() = default;
    };

    /** Branch node: a condition together with the IDs of its true and false children */
    struct BranchNode : public Node
    {
        /** Constructor
         *
         * @param id     Branch node ID
         * @param cond   Branch node @ref Condition
         * @param t_node True node's ID
         * @param f_node False node's ID
         */
        BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
            // cond is a by-value sink parameter: move it into the member instead of copying it a second time
            : id{id}, condition{std::move(cond)}, true_node{t_node}, false_node{f_node}
        {
        }
        NodeType type() const override
        {
            return NodeType::Branch;
        }
        NodeID    id;         /**< Branch node ID */
        Condition condition;  /**< Branch condition */
        NodeID    true_node;  /**< True node's ID */
        NodeID    false_node; /**< False node's ID */
    };

    /** Leaf node holding a value
     *
     * @tparam T Leaf value type
     */
    template <typename T>
    struct LeafNode : public Node
    {
        /** Constructor
         *
         * @param id  Leaf node ID
         * @param val Leaf node value
         */
        // val is a by-value sink parameter: move it into the member instead of copying it a second time
        LeafNode(NodeID id, T val) : id{id}, value{std::move(val)}
        {
        }
        NodeType type() const override
        {
            return NodeType::Leaf;
        }
        NodeID id;    /**< Leaf node ID */
        T      value; /**< Leaf node value */
    };

public:
    /** Constructor */
    HeuristicTree();
    /** Constructor
     *
     * @param id        Heuristic tree ID
     * @param h_type    Heuristic type
     * @param ip_target IP target associated with the tree
     * @param data_type Data type associated with the tree
     */
    HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type);
    // Since the HeuristicTree is a handle that owns the nodes, it is move-only
    /** Prevent copy construction */
    HeuristicTree(const HeuristicTree &) = delete;
    /** Prevent copy assignment */
    HeuristicTree &operator=(const HeuristicTree &) = delete;
    /** Move constructor */
    HeuristicTree(HeuristicTree &&other) noexcept = default;
    /** Move assignment */
    HeuristicTree &operator=(HeuristicTree &&other) = default;

    /** Query a leaf value given a gemm shape
     *
     * @tparam T           Leaf value type
     * @param shape        A @ref GEMMShape for the query
     * @return std::pair<bool, T> Outcome contains bool, signalling if the query succeeded or not
     */
    template <typename T>
    std::pair<bool, T> query(GEMMShape shape) const;

    /** Add a leaf node
     *
     * @tparam T  Leaf value type
     * @param id  Leaf node ID
     * @param leaf_value Leaf node value
     * @return bool  If the addition succeeded or not
     */
    template <typename T>
    bool add_leaf(NodeID id, T leaf_value);
    /** Add a branch node
     *
     * @param id  Branch node ID
     * @param cond Branch node @ref Condition
     * @param true_node  True node's ID
     * @param false_node  False node's ID
     * @return bool  If the addition succeeded or not
     */
    bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node);

    /** Get tree ID
     * @return TreeID
     */
    TreeID id() const
    {
        return _id;
    }

    /** Get tree index
     *
     * The index is built from the heuristic type, the IP target and the data type of the tree.
     *
     * @return Index
     */
    Index index() const
    {
        return std::make_tuple(_heuristic_type, _ip_target, _data_type);
    }

    /** Check if tree is valid
     * @return bool
     */
    bool check();

private:
    static constexpr std::size_t _max_query_depth{1000}; // Maximum depth of query
    static constexpr std::size_t _max_num_nodes{100000}; // Maximum number of nodes contained by the tree
    static constexpr NodeID      _root{0};               // ID of the root node

private:
    /** Check if the tree is structurally correct
     * @return bool
     */
    bool check_if_structurally_correct() const;

private:
    TreeID                                  _id;             /**< Heuristic tree ID */
    HeuristicType                           _heuristic_type; /**< Heuristic type */
    std::string                             _ip_target;      /**< IP target associated with the tree */
    DataType                                _data_type;      /**< Data type associated with the tree */
    std::map<NodeID, std::unique_ptr<Node>> _tree;           /**< Tree representation */
};
} // namespace mlgo

} // namespace arm_compute

#endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H