//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/ArmNN.hpp>
#include <armnn/TypesUtils.hpp>
#include "InferenceModel.hpp"

#include <Logging.hpp>

#include <boost/log/core/core.hpp>
#include <boost/program_options.hpp>


namespace armnn
{

/// Parses a compute device name (e.g. "CpuRef", "CpuAcc", "GpuAcc") from a stream.
/// boost::program_options uses this to convert command-line tokens into armnn::Compute values.
inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
{
    std::string token;
    in >> token;
    compute = armnn::ParseComputeDevice(token.c_str());
    if (compute == armnn::Compute::Undefined)
    {
        in.setstate(std::ios_base::failbit);
        throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
    }
    return in;
}

/// As above, but wraps the parsed compute device in an armnn::BackendId.
inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
{
    armnn::Compute compute;
    in >> compute; // Delegates parsing and error handling to the operator above.
    backend = compute;
    return in;
}
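
// A minimal usage sketch (hypothetical, not part of this header): the extraction
// operators above let boost::program_options parse device names straight into
// ArmNN types, e.g.
//
//   namespace po = boost::program_options;
//   po::options_description desc("Options");
//   std::vector<armnn::BackendId> backends;
//   desc.add_options()
//       ("compute", po::value<std::vector<armnn::BackendId>>(&backends)->multitoken(),
//        "Backends to test on, e.g. CpuRef CpuAcc GpuAcc");
//
// An unrecognised device name makes operator>> throw validation_error, which
// program_options reports as an invalid option value.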

namespace test
{

class TestFrameworkException : public Exception
{
public:
    using Exception::Exception;
};

struct InferenceTestOptions
{
    unsigned int m_IterationCount;
    std::string m_InferenceTimesFile;
    bool m_EnableProfiling;

    InferenceTestOptions()
        : m_IterationCount(0),
          m_EnableProfiling(false)
    {}
};

enum class TestCaseResult
{
    /// The test completed without any errors.
    Ok,
    /// The test failed (e.g. the prediction didn't match the validation file).
    /// This will eventually fail the whole program but the remaining test cases will still be run.
    Failed,
    /// The test failed with a fatal error. The remaining tests will not be run.
    Abort
};

class IInferenceTestCase
{
public:
    virtual ~IInferenceTestCase() {}

    virtual void Run() = 0;
    virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
};

class IInferenceTestCaseProvider
{
public:
    virtual ~IInferenceTestCaseProvider() {}

    virtual void AddCommandLineOptions(boost::program_options::options_description& options) {}
    virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) { return true; }
    virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
    virtual bool OnInferenceTestFinished() { return true; }
};
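
// A hypothetical minimal provider, sketching the interface contract (MyTestCase
// and MyProvider are illustrative names only):
//
//   class MyProvider : public IInferenceTestCaseProvider
//   {
//   public:
//       std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
//       {
//           return std::make_unique<MyTestCase>(testCaseId);
//       }
//   };
//
// The remaining virtuals (AddCommandLineOptions, ProcessCommandLineOptions,
// OnInferenceTestFinished) keep their default no-op behaviour unless overridden.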

template <typename TModel>
class InferenceModelTestCase : public IInferenceTestCase
{
public:
    using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;

    InferenceModelTestCase(TModel& model,
                           unsigned int testCaseId,
                           std::vector<TContainer> inputs, // Taken by value so the std::move below actually moves.
                           const std::vector<unsigned int>& outputSizes)
        : m_Model(model)
        , m_TestCaseId(testCaseId)
        , m_Inputs(std::move(inputs))
    {
        // Initialize the output containers: one per output, sized to the expected number of elements.
        const size_t numOutputs = outputSizes.size();
        m_Outputs.reserve(numOutputs);

        for (size_t i = 0; i < numOutputs; i++)
        {
            m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
        }
    }

    virtual void Run() override
    {
        m_Model.Run(m_Inputs, m_Outputs);
    }

protected:
    unsigned int GetTestCaseId() const { return m_TestCaseId; }
    const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }

private:
    TModel&                 m_Model;
    unsigned int            m_TestCaseId;
    std::vector<TContainer> m_Inputs;
    std::vector<TContainer> m_Outputs;
};

template <typename TTestCaseDatabase, typename TModel>
class ClassifierTestCase : public InferenceModelTestCase<TModel>
{
public:
    ClassifierTestCase(int& numInferencesRef,
        int& numCorrectInferencesRef,
        const std::vector<unsigned int>& validationPredictions,
        std::vector<unsigned int>* validationPredictionsOut,
        TModel& model,
        unsigned int testCaseId,
        unsigned int label,
        std::vector<typename TModel::DataType> modelInput);

    virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;

private:
    unsigned int m_Label;
    InferenceModelInternal::QuantizationParams m_QuantizationParams;

    /// Each of these fields references the corresponding member in the ClassifierTestCaseProvider.
    /// @{
    int& m_NumInferencesRef;
    int& m_NumCorrectInferencesRef;
    const std::vector<unsigned int>& m_ValidationPredictions;
    std::vector<unsigned int>* m_ValidationPredictionsOut;
    /// @}
};

template <typename TDatabase, typename InferenceModel>
class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
{
public:
    template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
    ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);

    virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
    virtual bool ProcessCommandLineOptions(const InferenceTestOptions& commonOptions) override;
    virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
    virtual bool OnInferenceTestFinished() override;

private:
    void ReadPredictions();

    typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
    std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
                                                  typename InferenceModel::CommandLineOptions)> m_ConstructModel;
    std::unique_ptr<InferenceModel> m_Model;

    std::string m_DataDir;
    std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
    std::unique_ptr<TDatabase> m_Database;

    int m_NumInferences; // Referenced by test cases.
    int m_NumCorrectInferences; // Referenced by test cases.

    std::string m_ValidationFileIn;
    std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.

    std::string m_ValidationFileOut;
    std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
};
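
// Construction sketch (MyDatabase and MyModel are hypothetical types): the two
// callables defer creating the database and model until after command-line
// parsing, since both may depend on parsed options:
//
//   ClassifierTestCaseProvider<MyDatabase, MyModel> provider(
//       [](const char* dataDir, const MyModel&) { return MyDatabase(dataDir); },
//       [](const InferenceTestOptions& commonOptions,
//          MyModel::CommandLineOptions modelOptions)
//       {
//           return std::make_unique<MyModel>(/* built from commonOptions/modelOptions */);
//       });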

bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
    InferenceTestOptions& outParams);

bool ValidateDirectory(std::string& dir);

bool InferenceTest(const InferenceTestOptions& params,
    const std::vector<unsigned int>& defaultTestCaseIds,
    IInferenceTestCaseProvider& testCaseProvider);

template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
    char* argv[],
    const std::vector<unsigned int>& defaultTestCaseIds,
    TConstructTestCaseProvider constructTestCaseProvider);
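
// Typical entry point (a sketch; MyProvider is the hypothetical provider above,
// and the lambda is assumed to return a std::unique_ptr to it):
//
//   int main(int argc, char* argv[])
//   {
//       return armnn::test::InferenceTestMain(argc, argv, { 0 },
//           [] { return std::make_unique<MyProvider>(); });
//   }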

template<typename TDatabase,
    typename TParser,
    typename TConstructDatabaseCallable>
int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
    const char* inputBindingName, const char* outputBindingName,
    const std::vector<unsigned int>& defaultTestCaseIds,
    TConstructDatabaseCallable constructDatabase,
    const armnn::TensorShape* inputTensorShape = nullptr);
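
// The classifier convenience wrapper above wires parser, model and database
// together; a hypothetical call (file and binding names illustrative):
//
//   int main(int argc, char* argv[])
//   {
//       return armnn::test::ClassifierInferenceTestMain<MyDatabase, MyParser>(
//           argc, argv, "model.pb", true /* binary */, "input", "output", { 0 },
//           [](const char* dataDir, const auto& model) { return MyDatabase(dataDir); });
//   }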

} // namespace test
} // namespace armnn

#include "InferenceTest.inl"