aboutsummaryrefslogtreecommitdiff
path: root/samples/ObjectDetection/src/Main.cpp
blob: 8bc2f0de381a06024d840e2bc42d08d7d3cd1629 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CvVideoFrameReader.hpp"
#include "CvWindowOutput.hpp"
#include "CvVideoFileWriter.hpp"
#include "ObjectDetectionPipeline.hpp"
#include "CmdArgsParser.hpp"

#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

const std::string MODEL_NAME = "--model-name";
const std::string VIDEO_FILE_PATH = "--video-file-path";
const std::string MODEL_FILE_PATH = "--model-file-path";
const std::string OUTPUT_VIDEO_FILE_PATH = "--output-video-file-path";
const std::string LABEL_PATH = "--label-path";
const std::string PREFERRED_BACKENDS = "--preferred-backends";
// NOTE(review): this flag uses an underscore while every other option uses dashes. It is kept as-is
// because renaming it would break existing command lines.
const std::string PROFILING_ENABLED = "--profiling_enabled";
const std::string HELP = "--help";

/*
 * The accepted options for this Object detection executable, mapped to the help text shown for each.
 * Adjacent string literals are concatenated by the compiler, so any fragment continued on the next line
 * must carry its own separating space.
 */
static std::map<std::string, std::string> CMD_OPTIONS = {
        {VIDEO_FILE_PATH, "[REQUIRED] Path to the video file to run object detection on"},
        {MODEL_FILE_PATH, "[REQUIRED] Path to the Object Detection model to use"},
        {LABEL_PATH, "[REQUIRED] Path to the label set for the provided model file. "
                     "Label file should be an ordered list, separated by a new line."},
        {MODEL_NAME, "[REQUIRED] The name of the model being used. Accepted options: YOLO_V3_TINY, SSD_MOBILE"},
        {OUTPUT_VIDEO_FILE_PATH, "[OPTIONAL] Path to the output video file with detections added in. "
                                 "If specified will save file to disk, else displays the output to screen"},
        {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma."
                             " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]."
                             " Defaults to CpuAcc,CpuRef"},
        {PROFILING_ENABLED, "[OPTIONAL] Enabling this option will print important ML related milestones timing "
                            "information in micro-seconds. By default, this option is disabled. "
                            "Accepted options are true/false."}
};

/*
 * Splits the user supplied backend preference string on commas and returns the backends in the order given.
 * Leading/trailing spaces and tabs are stripped from each entry. Empty entries (from "CpuAcc,,CpuRef",
 * a trailing comma, or an empty string) are dropped, so an empty input yields an empty vector.
 */
std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
{
    std::vector<armnn::BackendId> backends;
    std::stringstream ss(preferredBackends);
    std::string backend;

    // Loop on getline's result rather than ss.good(). Checking good() first still runs one failing
    // getline for an empty input or a trailing comma, which appended an empty backend name.
    while (std::getline(ss, backend, ','))
    {
        const auto first = backend.find_first_not_of(" \t");
        if (first == std::string::npos)
        {
            continue; // empty or whitespace-only entry
        }
        const auto last = backend.find_last_not_of(" \t");
        backends.emplace_back(backend.substr(first, last - first + 1));
    }
    return backends;
}

/*
 * Reads the label file (an ordered list, one label per line) and pairs each non-empty label with a
 * pseudo-random bounding-box colour. The generator keeps its default seed, so colours repeat from run to run.
 * A trailing '\r' (CRLF line endings) is removed from each line.
 * If the file cannot be opened, a warning is printed to stderr and an empty list is returned.
 */
std::vector<std::tuple<std::string, common::BBoxColor>> AssignColourToLabel(const std::string& pathToLabelFile)
{
    std::vector<std::tuple<std::string, common::BBoxColor>> labels;

    std::ifstream in(pathToLabelFile);
    if (!in.is_open())
    {
        std::cerr << "Failed to open label file: " << pathToLabelFile << '\n';
        return labels;
    }

    std::default_random_engine generator;
    std::uniform_int_distribution<int> distribution(0, 255);

    std::string line;
    while (std::getline(in, line))
    {
        // Tolerate Windows (CRLF) line endings so '\r' does not end up in the label text.
        if (!line.empty() && line.back() == '\r')
        {
            line.pop_back();
        }
        if (line.empty())
        {
            continue;
        }

        // Draw the components in separate statements. Function-argument evaluation order is unspecified,
        // so drawing all three inside one make_tuple call lets different compilers assign the random
        // values to different positions of the colour tuple.
        const int c0 = distribution(generator);
        const int c1 = distribution(generator);
        const int c2 = distribution(generator);

        common::BBoxColor colour{
            .colorCode = std::make_tuple(c0, c1, c2)
        };
        // `line` is moved from here; the next getline call reassigns it.
        labels.emplace_back(std::move(line), colour);
    }
    return labels;
}

std::tuple<std::unique_ptr<common::IFrameReader<cv::Mat>>,
           std::unique_ptr<common::IFrameOutput<cv::Mat>>>
           GetFrameSourceAndSink(const std::map<std::string, std::string>& options) {

    std::unique_ptr<common::IFrameReader<cv::Mat>> readerPtr;

    std::unique_ptr<common::CvVideoFrameReader> reader = std::make_unique<common::CvVideoFrameReader>();
    reader->Init(GetSpecifiedOption(options, VIDEO_FILE_PATH));

    auto enc = reader->GetSourceEncodingInt();
    auto fps = reader->GetSourceFps();
    auto w = reader->GetSourceWidth();
    auto h = reader->GetSourceHeight();
    if (!reader->ConvertToRGB())
    {
        readerPtr = std::move(std::make_unique<common::CvVideoFrameReaderRgbWrapper>(std::move(reader)));
    }
    else
    {
        readerPtr = std::move(reader);
    }

    if(CheckOptionSpecified(options, OUTPUT_VIDEO_FILE_PATH))
    {
        std::string outputVideo = GetSpecifiedOption(options, OUTPUT_VIDEO_FILE_PATH);
        auto writer = std::make_unique<common::CvVideoFileWriter>();
        writer->Init(outputVideo, enc, fps, w, h);

        return std::make_tuple<>(std::move(readerPtr), std::move(writer));
    }
    else
    {
        auto writer = std::make_unique<common::CvWindowOutput>();
        writer->Init("Processed Video");
        return std::make_tuple<>(std::move(readerPtr), std::move(writer));
    }
}

int main(int argc, char *argv[])
{
    std::map<std::string, std::string> options;

    int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
    if (result != 0)
    {
        return result;
    }

    // Create the network options
    common::PipelineOptions pipelineOptions;
    pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
    pipelineOptions.m_ModelName = GetSpecifiedOption(options, MODEL_NAME);

    if (CheckOptionSpecified(options, PROFILING_ENABLED))
    {
        pipelineOptions.m_ProfilingEnabled = GetSpecifiedOption(options, PROFILING_ENABLED) == "true";
    }
    if(CheckOptionSpecified(options, PREFERRED_BACKENDS))
    {
        pipelineOptions.m_backends = GetPreferredBackendList((GetSpecifiedOption(options, PREFERRED_BACKENDS)));
    }
    else
    {
        pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
    }

    auto labels = AssignColourToLabel(GetSpecifiedOption(options, LABEL_PATH));

    common::Profiling profiling(pipelineOptions.m_ProfilingEnabled);
    profiling.ProfilingStart();
    od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(pipelineOptions);

    auto inputAndOutput = GetFrameSourceAndSink(options);
    std::unique_ptr<common::IFrameReader<cv::Mat>> reader = std::move(std::get<0>(inputAndOutput));
    std::unique_ptr<common::IFrameOutput<cv::Mat>> sink = std::move(std::get<1>(inputAndOutput));

    if (!sink->IsReady())
    {
        std::cerr << "Failed to open video writer.";
        return 1;
    }

    common::InferenceResults<float> results;

    std::shared_ptr<cv::Mat> frame = reader->ReadFrame();

    //pre-allocate frames
    cv::Mat processed;

    while(!reader->IsExhausted(frame))
    {
        objectDetectionPipeline->PreProcessing(*frame, processed);
        objectDetectionPipeline->Inference(processed, results);
        objectDetectionPipeline->PostProcessing(results,
                                                [&frame, &labels](od::DetectedObjects detects) -> void {
            AddInferenceOutputToFrame(detects, *frame, labels);
        });

        sink->WriteFrame(frame);
        frame = reader->ReadFrame();
    }
    sink->Close();
    profiling.ProfilingStopAndPrintUs("Overall compute time");
    return 0;
}