Diffstat (limited to 'samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp')
-rw-r--r-- | samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp | 59
1 file changed, 28 insertions, 31 deletions
diff --git a/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp b/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp
index 47ce30416f..bc3fbfe151 100644
--- a/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp
+++ b/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp
@@ -8,16 +8,16 @@
 #include "ArmnnNetworkExecutor.hpp"
 #include "Decoder.hpp"
 #include "MFCC.hpp"
-#include "Preprocess.hpp"
+#include "Wav2LetterPreprocessor.hpp"
 
-namespace asr 
+namespace asr
 {
 /**
  * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference
  * result post-processing.
  *
  */
-class ASRPipeline 
+class ASRPipeline
 {
 public:
 
@@ -27,7 +27,7 @@ public:
      * @param decoder - unique pointer to inference results decoder
      */
     ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
-                std::unique_ptr<Decoder> decoder);
+                std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor);
 
     /**
      * @brief Standard audio pre-processing implementation.
@@ -36,20 +36,16 @@ public:
      * extracting the MFCC features.
 
      * @param[in] audio - the raw audio data
-     * @param[out] preprocessor - the preprocessor object, which handles the data prepreration
+     * @param[out] preprocessor - the preprocessor object, which handles the data preparation
      */
-    template<typename Tin,typename Tout>
-    std::vector<Tout> PreProcessing(std::vector<Tin>& audio, Preprocess& preprocessor)
-    {
-        int audioDataToPreProcess = preprocessor._m_windowLen +
-                ((preprocessor._m_mfcc._m_params.m_numMfccVectors -1) *preprocessor._m_windowStride);
-        int outputBufferSize = preprocessor._m_mfcc._m_params.m_numMfccVectors
-                * preprocessor._m_mfcc._m_params.m_numMfccFeatures * 3;
-        std::vector<Tout> outputBuffer(outputBufferSize);
-        preprocessor.Invoke(audio.data(), audioDataToPreProcess, outputBuffer, m_executor->GetQuantizationOffset(),
-                            m_executor->GetQuantizationScale());
-        return outputBuffer;
-    }
+    std::vector<int8_t> PreProcessing(std::vector<float>& audio);
+
+    int getInputSamplesSize();
+    int getSlidingWindowOffset();
+
+    // Exposing hardcoded constant as it can only be derived from model knowledge and not from model itself
+    // Will need to be refactored so that hard coded values are not defined outside of model settings
+    int SLIDING_WINDOW_OFFSET;
 
     /**
      * @brief Executes inference
@@ -60,9 +56,9 @@ public:
      * @param[out] result - raw inference results.
      */
     template<typename T>
-    void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result) 
+    void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
     {
-        size_t data_bytes = sizeof(std::vector<T>) + (sizeof(T) * preprocessedData.size());
+        size_t data_bytes = sizeof(T) * preprocessedData.size();
         m_executor->Run(preprocessedData.data(), data_bytes, result);
     }
 
@@ -78,9 +74,9 @@ public:
      */
     template<typename T>
     void PostProcessing(common::InferenceResults<int8_t>& inferenceResult,
-                        bool& isFirstWindow,
-                        bool isLastWindow,
-                        std::string currentRContext) 
+                        bool& isFirstWindow,
+                        bool isLastWindow,
+                        std::string currentRContext)
     {
         int rowLength = 29;
         int middleContextStart = 49;
@@ -92,17 +88,17 @@ public:
         std::vector<T> contextToProcess;
 
         // If isFirstWindow we keep the left context of the output
-        if(isFirstWindow)
+        if (isFirstWindow)
        {
             std::vector<T> chunk(&inferenceResult[0][leftContextStart],
-                    &inferenceResult[0][middleContextEnd * rowLength]);
+                                 &inferenceResult[0][middleContextEnd * rowLength]);
             contextToProcess = chunk;
         }
-        // Else we only keep the middle context of the output
-        else 
+        else
         {
+            // Else we only keep the middle context of the output
             std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength],
-                    &inferenceResult[0][middleContextEnd * rowLength]);
+                                 &inferenceResult[0][middleContextEnd * rowLength]);
             contextToProcess = chunk;
         }
         std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess);
@@ -110,10 +106,10 @@ public:
         std::cout << output << std::flush;
 
         // If this is the last window, we print the right context of the output
-        if(isLastWindow)
+        if (isLastWindow)
         {
-            std::vector<T> rContext(&inferenceResult[0][rightContextStart*rowLength],
-                    &inferenceResult[0][rightContextEnd * rowLength]);
+            std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength],
+                                    &inferenceResult[0][rightContextEnd * rowLength]);
             currentRContext = this->m_decoder->DecodeOutput(rContext);
             std::cout << currentRContext << std::endl;
         }
@@ -122,6 +118,7 @@ public:
 protected:
     std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
     std::unique_ptr<Decoder> m_decoder;
+    std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor;
 };
 
 using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;
@@ -136,4 +133,4 @@ using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;
  */
 IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels);
 
-}// namespace asr
\ No newline at end of file
+} // namespace asr
\ No newline at end of file
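For reference, a minimal driver sketch of how the refactored header above might be used: slide a window over the raw audio and run the three pipeline stages in order. This is an illustrative assumption, not code from this commit: RunAsr, the window-slicing arithmetic and the empty labels map are hypothetical; getSlidingWindowOffset() is assumed to return the stride between windows in samples; common::InferenceResults<int8_t> is assumed default-constructible; and PostProcessing() is assumed to clear isFirstWindow after the first call, which is why the header takes it by reference.

// Hypothetical driver sketch for the refactored pipeline (assumptions noted above).
#include <map>
#include <string>
#include <vector>

#include "SpeechRecognitionPipeline.hpp"

void RunAsr(common::PipelineOptions& options, std::vector<float>& audio)
{
    std::map<int, std::string> labels; // assumed: populated from the model's label file
    asr::IPipelinePtr pipeline = asr::CreatePipeline(options, labels);

    const size_t windowSize = static_cast<size_t>(pipeline->getInputSamplesSize());
    const size_t stride     = static_cast<size_t>(pipeline->getSlidingWindowOffset());

    bool isFirstWindow = true;
    std::string rightContext;

    for (size_t start = 0; start + windowSize <= audio.size(); start += stride)
    {
        std::vector<float> window(audio.begin() + start,
                                  audio.begin() + start + windowSize);
        bool isLastWindow = (start + stride + windowSize > audio.size());

        // Step 1: pre-processing - MFCC extraction and quantisation to the int8 network input.
        std::vector<int8_t> input = pipeline->PreProcessing(window);

        // Step 2: inference - data_bytes inside Inference() is now just
        // sizeof(T) * element count, without the spurious sizeof(std::vector<T>) term.
        common::InferenceResults<int8_t> results;
        pipeline->Inference<int8_t>(input, results);

        // Step 3: post-processing - decode left/middle/right context depending on
        // whether this is the first, a middle, or the last window.
        pipeline->PostProcessing<int8_t>(results, isFirstWindow, isLastWindow, rightContext);
    }
}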