aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/tuners/CLLWSList.h
blob: 7ce10ac22002cf1e5e7040a8ad354d4f68b08a5c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
/*
 * Copyright (c) 2019 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_CL_LWS_LIST_H
#define ARM_COMPUTE_CL_LWS_LIST_H

#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/runtime/CL/CLTunerTypes.h"
#include "support/ToolchainSupport.h"
#include <cstddef>
#include <memory>
#include <vector>

namespace arm_compute
{
namespace cl_tuner
{
/** Maximum LWS value supported along the x dimension */
constexpr unsigned int max_lws_supported_x = 64u;
/** Maximum LWS value supported along the y dimension */
constexpr unsigned int max_lws_supported_y = 32u;
/** Maximum LWS value supported along the z dimension */
constexpr unsigned int max_lws_supported_z = 32u;

/** Abstract interface implemented by every LWS list */
class ICLLWSList
{
public:
    /** Default constructor */
    ICLLWSList() = default;
    /** Allow instances of this class to be copy constructed */
    ICLLWSList(const ICLLWSList &) = default;
    /** Allow instances of this class to be move constructed */
    ICLLWSList(ICLLWSList &&) noexcept = default;
    /** Allow instances of this class to be copied */
    ICLLWSList &operator=(const ICLLWSList &) = default;
    /** Allow instances of this class to be moved */
    ICLLWSList &operator=(ICLLWSList &&) noexcept = default;
    /** Virtual destructor, so that implementations can be destroyed through this interface */
    virtual ~ICLLWSList() = default;

    /** Access one entry of the list.
     *
     * @param[in] index Position of the requested entry
     *
     * @return LWS value at the given index
     */
    virtual cl::NDRange operator[](size_t index) = 0;

    /** Number of entries held by the list.
     *
     * @return LWS list size
     */
    virtual size_t size() = 0;
};

/** Non instantiable base class for LWS combinations that use Index2Cooard mapping */
class CLLWSList : public ICLLWSList
{
protected:
    /* Shape of 3-D search space */
    TensorShape search_space_shape{ 0, 0, 0 };

    /** Constructor */
    CLLWSList() = default;
    /** Copy Constructor */
    CLLWSList(const CLLWSList &) = default;
    /** Move Constructor */
    CLLWSList(CLLWSList &&) noexcept(true) = default;
    /** Assignment */
    CLLWSList &operator=(const CLLWSList &) = default;
    /** Move Assignment */
    CLLWSList &operator=(CLLWSList &&) noexcept(true) = default;
    /** Destructor */
    virtual ~CLLWSList() = default;

    // Inherited methods overridden:
    virtual size_t size() override;
};

/** LWS list that enumerates every possible LWS value */
class CLLWSListExhaustive : public CLLWSList
{
public:
    /** Default construction is not allowed */
    CLLWSListExhaustive() = delete;
    /** Constructor
     *
     * @param[in] gws gws configuration the list is created for
     */
    CLLWSListExhaustive(const cl::NDRange &gws);
    /** Allow instances of this class to be copy constructed */
    CLLWSListExhaustive(const CLLWSListExhaustive &) = default;
    /** Allow instances of this class to be move constructed */
    CLLWSListExhaustive(CLLWSListExhaustive &&) noexcept = default;
    /** Allow instances of this class to be copied */
    CLLWSListExhaustive &operator=(const CLLWSListExhaustive &) = default;
    /** Allow instances of this class to be moved */
    CLLWSListExhaustive &operator=(CLLWSListExhaustive &&) noexcept = default;
    /** Destructor */
    ~CLLWSListExhaustive() = default;

    // Inherited methods overridden:
    cl::NDRange operator[](size_t) override;
};

/** Subset of LWS values that are either powers of 2 or, when gws[2] < 16, factors of gws */
class CLLWSListNormal : public CLLWSList
{
public:
    /** Constructor
     *
     * @param[in] gws gws configuration the list is created for
     */
    CLLWSListNormal(const cl::NDRange &gws);
    /** Allow instances of this class to be copy constructed */
    CLLWSListNormal(const CLLWSListNormal &) = default;
    /** Allow instances of this class to be move constructed */
    CLLWSListNormal(CLLWSListNormal &&) noexcept = default;
    /** Allow instances of this class to be copied */
    CLLWSListNormal &operator=(const CLLWSListNormal &) = default;
    /** Allow instances of this class to be moved */
    CLLWSListNormal &operator=(CLLWSListNormal &&) noexcept = default;
    /** Destructor */
    ~CLLWSListNormal() = default;

    // Inherited methods overridden:
    cl::NDRange operator[](size_t) override;

protected:
    /** LWS values for the x dimension */
    std::vector<unsigned int> _lws_x{};
    /** LWS values for the y dimension */
    std::vector<unsigned int> _lws_y{};
    /** LWS values for the z dimension */
    std::vector<unsigned int> _lws_z{};

    /** Default constructor.
     *
     * Protected, so it is not part of the public interface: only derived classes can call it.
     */
    CLLWSListNormal() = default;

private:
    /** Helper that fills in the LWS values to test.
     *  The CLTuner only considers LWS values that are a power of 2 or that satisfy the modulo conditions with GWS
     *
     * @param[in, out] lws         Vector of LWS to test
     * @param[in]      gws         Size of the specific GWS
     * @param[in]      lws_max     Max LWS value allowed to be tested
     * @param[in]      mod_let_one True if the result of the modulo operation between gws and the lws can be less than one.
     */
    void initialize_lws_values(std::vector<unsigned int> &lws, unsigned int gws, unsigned int lws_max, bool mod_let_one);
};

/** A minimal subset of LWS values that only have 1,2 and 4/8 */
class CLLWSListRapid : public CLLWSListNormal
{
public:
    /** Prevent default constructor calls */
    CLLWSListRapid() = delete;
    /** Constructor */
    CLLWSListRapid(const cl::NDRange &gws);
    /** Copy Constructor */
    CLLWSListRapid(const CLLWSListRapid &) = default;
    /** Move Constructor */
    CLLWSListRapid(CLLWSListRapid &&) noexcept(true) = default;
    /** Assignment */
    CLLWSListRapid &operator=(const CLLWSListRapid &) = default;
    /** Move Assignment */
    CLLWSListRapid &operator=(CLLWSListRapid &&) noexcept(true) = default;
    /** Destructor */
    virtual ~CLLWSListRapid() = default;

private:
    /** Utility function used to initialize the LWS values to test.
     *  Only the LWS values that have 1,2 and 4/8 for each dimension are taken into account by the CLTuner
     *
     * @param[in, out] lws     Vector of LWS to test
     * @param[in]      lws_max Max LWS value allowed to be tested
     */
    void initialize_lws_values(std::vector<unsigned int> &lws, unsigned int lws_max);
};

/** Factory that builds the ICLLWSList implementation matching a CL tuner mode */
class CLLWSListFactory final
{
public:
    /** Create the LWS list associated with the given tuner mode and gws configuration.
     *
     * @param[in] mode Tuner mode selecting which list implementation is built
     * @param[in] gws  gws configuration passed to the constructor of the selected list
     *
     * @return unique_ptr to the requested ICLLWSList implementation, or nullptr if @p mode is none of EXHAUSTIVE, NORMAL or RAPID.
     */
    static std::unique_ptr<ICLLWSList> get_lws_list(CLTunerMode mode, const cl::NDRange &gws)
    {
        // Stays empty unless one of the known modes below fills it in
        std::unique_ptr<ICLLWSList> lws_list{};

        switch(mode)
        {
            case CLTunerMode::EXHAUSTIVE:
                lws_list = arm_compute::support::cpp14::make_unique<CLLWSListExhaustive>(gws);
                break;
            case CLTunerMode::NORMAL:
                lws_list = arm_compute::support::cpp14::make_unique<CLLWSListNormal>(gws);
                break;
            case CLTunerMode::RAPID:
                lws_list = arm_compute::support::cpp14::make_unique<CLLWSListRapid>(gws);
                break;
            default:
                break;
        }

        return lws_list;
    }
};
} // namespace cl_tuner
} // namespace arm_compute
#endif /*ARM_COMPUTE_CL_LWS_LIST_H */