diff options
author | Éanna Ó Catháin <eanna.ocathain@arm.com> | 2021-04-07 14:35:25 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2021-05-07 09:11:52 +0000 |
commit | c6ab02a626e15b4a12fc09ecd844eb8b95380c3c (patch) | |
tree | 9912ed9cdb89cdb24483b22d6621ae30049ae321 /samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp | |
parent | e813d67f86df41a238ff79b5c554ef5027f56576 (diff) | |
download | armnn-c6ab02a626e15b4a12fc09ecd844eb8b95380c3c.tar.gz |
MLECO-1252 ASR sample application using the public ArmNN C++ API.
Change-Id: I98cd505b8772a8c8fa88308121bc94135bb45068
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp')
-rw-r--r-- | samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp | 139 |
1 file changed, 139 insertions(+), 0 deletions(-)
diff --git a/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp b/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp new file mode 100644 index 0000000000..47ce30416f --- /dev/null +++ b/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp @@ -0,0 +1,139 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ArmnnNetworkExecutor.hpp" +#include "Decoder.hpp" +#include "MFCC.hpp" +#include "Preprocess.hpp" + +namespace asr +{ +/** + * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference + * result post-processing. + * + */ +class ASRPipeline +{ +public: + + /** + * Creates speech recognition pipeline with given network executor and decoder. + * @param executor - unique pointer to inference runner + * @param decoder - unique pointer to inference results decoder + */ + ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor, + std::unique_ptr<Decoder> decoder); + + /** + * @brief Standard audio pre-processing implementation. + * + * Preprocesses and prepares the data for inference by + * extracting the MFCC features. 
+ + * @param[in] audio - the raw audio data + * @param[out] preprocessor - the preprocessor object, which handles the data prepreration + */ + template<typename Tin,typename Tout> + std::vector<Tout> PreProcessing(std::vector<Tin>& audio, Preprocess& preprocessor) + { + int audioDataToPreProcess = preprocessor._m_windowLen + + ((preprocessor._m_mfcc._m_params.m_numMfccVectors -1) *preprocessor._m_windowStride); + int outputBufferSize = preprocessor._m_mfcc._m_params.m_numMfccVectors + * preprocessor._m_mfcc._m_params.m_numMfccFeatures * 3; + std::vector<Tout> outputBuffer(outputBufferSize); + preprocessor.Invoke(audio.data(), audioDataToPreProcess, outputBuffer, m_executor->GetQuantizationOffset(), + m_executor->GetQuantizationScale()); + return outputBuffer; + } + + /** + * @brief Executes inference + * + * Calls inference runner provided during instance construction. + * + * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor. + * @param[out] result - raw inference results. + */ + template<typename T> + void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result) + { + size_t data_bytes = sizeof(std::vector<T>) + (sizeof(T) * preprocessedData.size()); + m_executor->Run(preprocessedData.data(), data_bytes, result); + } + + /** + * @brief Standard inference results post-processing implementation. + * + * Decodes inference results using decoder provided during construction. + * + * @param[in] inferenceResult - inference results to be decoded. + * @param[in] isFirstWindow - for checking if this is the first window of the sliding window. + * @param[in] isLastWindow - for checking if this is the last window of the sliding window. + * @param[in] currentRContext - the right context of the output text. To be output if it is the last window. 
+ */ + template<typename T> + void PostProcessing(common::InferenceResults<int8_t>& inferenceResult, + bool& isFirstWindow, + bool isLastWindow, + std::string currentRContext) + { + int rowLength = 29; + int middleContextStart = 49; + int middleContextEnd = 99; + int leftContextStart = 0; + int rightContextStart = 100; + int rightContextEnd = 148; + + std::vector<T> contextToProcess; + + // If isFirstWindow we keep the left context of the output + if(isFirstWindow) + { + std::vector<T> chunk(&inferenceResult[0][leftContextStart], + &inferenceResult[0][middleContextEnd * rowLength]); + contextToProcess = chunk; + } + // Else we only keep the middle context of the output + else + { + std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength], + &inferenceResult[0][middleContextEnd * rowLength]); + contextToProcess = chunk; + } + std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess); + isFirstWindow = false; + std::cout << output << std::flush; + + // If this is the last window, we print the right context of the output + if(isLastWindow) + { + std::vector<T> rContext(&inferenceResult[0][rightContextStart*rowLength], + &inferenceResult[0][rightContextEnd * rowLength]); + currentRContext = this->m_decoder->DecodeOutput(rContext); + std::cout << currentRContext << std::endl; + } + } + +protected: + std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor; + std::unique_ptr<Decoder> m_decoder; +}; + +using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>; + +/** + * Constructs speech recognition pipeline based on configuration provided. + * + * @param[in] config - speech recognition pipeline configuration. + * @param[in] labels - asr labels + * + * @return unique pointer to asr pipeline. + */ +IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels); + +}// namespace asr
\ No newline at end of file |