// path: root/samples/KeywordSpotting/src/Main.cpp
// blob: 10efcd8ce7726f36a23086f6303f5c2383650646
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "KeywordSpottingPipeline.hpp"
#include "CmdArgsParser.hpp"
#include "ArmnnNetworkExecutor.hpp"
#include "AudioCapture.hpp"

const std::string AUDIO_FILE_PATH = "--audio-file-path";
const std::string MODEL_FILE_PATH = "--model-file-path";
// NOTE(review): LABEL_PATH is not registered in CMD_OPTIONS and is not read anywhere in this file
// (the labels are hard-coded here instead). Kept so nothing that refers to it breaks.
const std::string LABEL_PATH = "--label-path";
const std::string PREFERRED_BACKENDS = "--preferred-backends";
// NOTE(review): HELP is not listed in CMD_OPTIONS; presumably ParseOptions handles it itself — confirm.
const std::string HELP = "--help";

/*
 * The accepted options for this keyword spotting executable: option flag -> description text.
 */
static std::map<std::string, std::string> CMD_OPTIONS =
{
        {AUDIO_FILE_PATH,    "[REQUIRED] Path to the audio file to run keyword spotting on"},
        {MODEL_FILE_PATH,    "[REQUIRED] Path to the keyword spotting model to use"},
        {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma."
                             " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]."
                             " Defaults to CpuAcc,CpuRef"}
};

/*
 * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector.
 *
 * Spaces and tabs around each entry are stripped, and empty entries (an empty argument, a trailing
 * comma, or ",,") are skipped, so "CpuAcc, GpuAcc," yields {"CpuAcc", "GpuAcc"}. The relative order
 * of the remaining entries is preserved.
 */
std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
{
    std::vector<armnn::BackendId> backends;
    std::stringstream ss(preferredBackends);
    std::string backend;

    // Looping on getline's result (instead of checking the stream state before reading) means a
    // read that hits end-of-input without extracting anything never produces a bogus empty token.
    while (std::getline(ss, backend, ','))
    {
        const auto first = backend.find_first_not_of(" \t");
        if (first == std::string::npos)
        {
            continue; // empty or whitespace-only entry, e.g. from "CpuAcc,,CpuRef"
        }
        const auto last = backend.find_last_not_of(" \t");
        backends.emplace_back(backend.substr(first, last - first + 1));
    }
    return backends;
}

/*
 * Class labels for this model, keyed by integer class index (0-11).
 * A label's position in the list below is its index, so the order must not change.
 */
std::map<int, std::string> labels = []()
{
    const char* const names[] = {"silence", "unknown", "yes",  "no",
                                 "up",      "down",    "left", "right",
                                 "on",      "off",     "stop", "go"};
    std::map<int, std::string> indexToLabel;
    int index = 0;
    for (const char* name : names)
    {
        indexToLabel.emplace(index++, name);
    }
    return indexToLabel;
}();


int main(int argc, char* argv[]) 
{
    printf("ArmNN major version: %d\n", ARMNN_MAJOR_VERSION);
    std::map<std::string, std::string> options;

    //Read command line args
    int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
    if (result != 0) 
    {
        return result;
    }

    // Create the ArmNN inference runner
    common::PipelineOptions pipelineOptions;
    pipelineOptions.m_ModelName = "DS_CNN_CLUSTERED_INT8";
    pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
    if (CheckOptionSpecified(options, PREFERRED_BACKENDS)) 
    {
        pipelineOptions.m_backends = GetPreferredBackendList(
            (GetSpecifiedOption(options, PREFERRED_BACKENDS)));
    } 
    else 
    {
        pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
    }

    kws::IPipelinePtr kwsPipeline = kws::CreatePipeline(pipelineOptions);

    //Extract audio data from sound file
    auto filePath = GetSpecifiedOption(options, AUDIO_FILE_PATH);
    std::vector<float> audioData = audio::AudioCapture::LoadAudioFile(filePath);

    audio::AudioCapture capture;
    //todo: read samples and stride from pipeline
    capture.InitSlidingWindow(audioData.data(), 
                              audioData.size(), 
                              kwsPipeline->getInputSamplesSize(), 
                              kwsPipeline->getInputSamplesSize()/2);

    //Loop through audio data buffer
    while (capture.HasNext()) 
    {
        std::vector<float> audioBlock = capture.Next();
        common::InferenceResults<int8_t> results;

        //Prepare input tensors
        std::vector<int8_t> preprocessedData = kwsPipeline->PreProcessing(audioBlock);
        //Run inference
        kwsPipeline->Inference(preprocessedData, results);
        //Decode output
        kwsPipeline->PostProcessing(results, labels,
                                    [](int index, std::string& label, float prob) -> void {
                                        printf("Keyword \"%s\", index %d:, probability %f\n",
                                               label.c_str(),
                                               index,
                                               prob);
                                    });
    }

    return 0;
}