path: root/tests/InferenceTest.cpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "InferenceTest.hpp"

#include <armnn/utility/Assert.hpp>
#include <Filesystem.hpp>

#include "../src/armnn/Profiling.hpp"
#include <cxxopts/cxxopts.hpp>

#include <fstream>
#include <iostream>
#include <iomanip>
#include <array>

using namespace std;
using namespace std::chrono;
using namespace armnn::test;

namespace armnn
{
namespace test
{
/// Parse the command line of an ArmNN (or referencetests) inference test program.
/// \return false if any error occurred during options processing, otherwise true
bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
    InferenceTestOptions& outParams)
{
    cxxopts::Options options("InferenceTest", "Inference iteration parameters");

    try
    {
        // Adds generic options needed for all inference tests.
        options
            .allow_unrecognised_options()
            .add_options()
                ("h,help", "Display help messages")
                ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
                 cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
                ("inference-times-file",
                 "If non-empty, each individual inference time will be recorded and output to this file",
                 cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
                ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
                 cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));

        std::vector<std::string> required; // To be passed by reference to derived inference tests.

        // Adds options specific to the ITestCaseProvider.
        testCaseProvider.AddCommandLineOptions(options, required);

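        // Parse the command line. Unrecognised options are allowed (and ignored) rather than causing an error.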
        auto result = options.parse(argc, argv);

        if (result.count("help"))
        {
            std::cout << options.help() << std::endl;
            return false;
        }

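        // Check that every option the test case provider marked as required was actually supplied.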
        CheckRequiredOptions(result, required);

    }
    catch (const cxxopts::OptionException& e)
    {
        std::cerr << e.what() << std::endl << options.help() << std::endl;
        return false;
    }
    catch (const std::exception& e)
    {
        ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
        std::cerr << "Fatal internal error: " << e.what() << std::endl;
        return false;
    }

    if (!testCaseProvider.ProcessCommandLineOptions(outParams))
    {
        return false;
    }

    return true;
}

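/// Validate a directory path given on the command line: reject an empty path, append a trailing '/'
/// if it is missing, and check that the path exists and refers to a directory.
/// \return true if the directory is usable, otherwise false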
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    if (dir[dir.length() - 1] != '/')
    {
        dir += "/";
    }

    if (!fs::exists(dir))
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
        return false;
    }

    if (!fs::is_directory(dir))
    {
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
        return false;
    }

    return true;
}

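/// Run the test cases supplied by the given provider, timing each inference and optionally recording
/// the individual inference times to the file named in the options.
/// \return true if all test cases were processed successfully, otherwise false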
bool InferenceTest(const InferenceTestOptions& params,
    const std::vector<unsigned int>& defaultTestCaseIds,
    IInferenceTestCaseProvider& testCaseProvider)
{
#if !defined (NDEBUG)
    if (params.m_IterationCount > 0) // If just running a few select images, don't bother to warn.
    {
        ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
    }
#endif

    double totalTime = 0;
    unsigned int nbProcessed = 0;
    bool success = true;

    // Opens the file to write inference times to, if needed.
    ofstream inferenceTimesFile;
    const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
    if (recordInferenceTimes)
    {
        inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
        if (!inferenceTimesFile.good())
        {
            ARMNN_LOG(error) << "Failed to open inference times file for writing: "
                << params.m_InferenceTimesFile;
            return false;
        }
    }

    // Create a profiler and register it for the current thread.
    std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
    ProfilerManager::GetInstance().RegisterProfiler(profiler.get());

    // Enable profiling if requested.
    profiler->EnableProfiling(params.m_EnableProfiling);

    // Run a single test case to 'warm up' the model. The first inference can sometimes take up to 10x longer.
    std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
    if (warmupTestCase == nullptr)
    {
        ARMNN_LOG(error) << "Failed to load test case";
        return false;
    }

    try
    {
        warmupTestCase->Run();
    }
    catch (const TestFrameworkException& testError)
    {
        ARMNN_LOG(error) << testError.what();
        return false;
    }

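    // If an iteration count was given, run that many inferences using sequential test case ids;
    // otherwise run each of the default test case ids once.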
    const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
        : static_cast<unsigned int>(defaultTestCaseIds.size());

    for (; nbProcessed < nbTotalToProcess; nbProcessed++)
    {
        const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
        std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);

        if (testCase == nullptr)
        {
            ARMNN_LOG(error) << "Failed to load test case";
            return false;
        }

        time_point<high_resolution_clock> predictStart;
        time_point<high_resolution_clock> predictEnd;

        TestCaseResult result = TestCaseResult::Ok;

        try
        {
            predictStart = high_resolution_clock::now();

            testCase->Run();

            predictEnd = high_resolution_clock::now();

            // duration<double> will convert the time difference into seconds as a double by default.
            double timeTakenS = duration<double>(predictEnd - predictStart).count();
            totalTime += timeTakenS;

            // Outputs the inference time, if needed.
            if (recordInferenceTimes)
            {
                inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
            }

            result = testCase->ProcessResult(params);

        }
        catch (const TestFrameworkException& testError)
        {
            ARMNN_LOG(error) << testError.what();
            result = TestCaseResult::Abort;
        }

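        // Decide whether to carry on, record a failure, or abort based on the test case result.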
        switch (result)
        {
        case TestCaseResult::Ok:
            break;
        case TestCaseResult::Abort:
            return false;
        case TestCaseResult::Failed:
            // This test failed so we will fail the entire program eventually, but keep going for now.
            success = false;
            break;
        default:
            ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
            return false;
        }
    }

    const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0;

    ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
        "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
    ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
        "Average time per test case: " << averageTimePerTestCaseMs << " ms";

    // If profiling is enabled, print out the results.
    if (profiler && profiler->IsProfilingEnabled())
    {
        profiler->Print(std::cout);
    }

    if (!success)
    {
        ARMNN_LOG(error) << "One or more test cases failed";
        return false;
    }

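    // Give the test case provider the final say on the overall result.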
    return testCaseProvider.OnInferenceTestFinished();
}

} // namespace test

} // namespace armnn