// // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "ArmnnNetworkExecutor.hpp" #include "Decoder.hpp" #include "MFCC.hpp" #include "Preprocess.hpp" namespace asr { /** * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference * result post-processing. * */ class ASRPipeline { public: /** * Creates speech recognition pipeline with given network executor and decoder. * @param executor - unique pointer to inference runner * @param decoder - unique pointer to inference results decoder */ ASRPipeline(std::unique_ptr> executor, std::unique_ptr decoder); /** * @brief Standard audio pre-processing implementation. * * Preprocesses and prepares the data for inference by * extracting the MFCC features. * @param[in] audio - the raw audio data * @param[out] preprocessor - the preprocessor object, which handles the data prepreration */ template std::vector PreProcessing(std::vector& audio, Preprocess& preprocessor) { int audioDataToPreProcess = preprocessor._m_windowLen + ((preprocessor._m_mfcc._m_params.m_numMfccVectors -1) *preprocessor._m_windowStride); int outputBufferSize = preprocessor._m_mfcc._m_params.m_numMfccVectors * preprocessor._m_mfcc._m_params.m_numMfccFeatures * 3; std::vector outputBuffer(outputBufferSize); preprocessor.Invoke(audio.data(), audioDataToPreProcess, outputBuffer, m_executor->GetQuantizationOffset(), m_executor->GetQuantizationScale()); return outputBuffer; } /** * @brief Executes inference * * Calls inference runner provided during instance construction. * * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor. * @param[out] result - raw inference results. */ template void Inference(const std::vector& preprocessedData, common::InferenceResults& result) { size_t data_bytes = sizeof(std::vector) + (sizeof(T) * preprocessedData.size()); m_executor->Run(preprocessedData.data(), data_bytes, result); } /** * @brief Standard inference results post-processing implementation. * * Decodes inference results using decoder provided during construction. * * @param[in] inferenceResult - inference results to be decoded. * @param[in] isFirstWindow - for checking if this is the first window of the sliding window. * @param[in] isLastWindow - for checking if this is the last window of the sliding window. * @param[in] currentRContext - the right context of the output text. To be output if it is the last window. */ template void PostProcessing(common::InferenceResults& inferenceResult, bool& isFirstWindow, bool isLastWindow, std::string currentRContext) { int rowLength = 29; int middleContextStart = 49; int middleContextEnd = 99; int leftContextStart = 0; int rightContextStart = 100; int rightContextEnd = 148; std::vector contextToProcess; // If isFirstWindow we keep the left context of the output if(isFirstWindow) { std::vector chunk(&inferenceResult[0][leftContextStart], &inferenceResult[0][middleContextEnd * rowLength]); contextToProcess = chunk; } // Else we only keep the middle context of the output else { std::vector chunk(&inferenceResult[0][middleContextStart * rowLength], &inferenceResult[0][middleContextEnd * rowLength]); contextToProcess = chunk; } std::string output = this->m_decoder->DecodeOutput(contextToProcess); isFirstWindow = false; std::cout << output << std::flush; // If this is the last window, we print the right context of the output if(isLastWindow) { std::vector rContext(&inferenceResult[0][rightContextStart*rowLength], &inferenceResult[0][rightContextEnd * rowLength]); currentRContext = this->m_decoder->DecodeOutput(rContext); std::cout << currentRContext << std::endl; } } protected: std::unique_ptr> m_executor; std::unique_ptr m_decoder; }; using IPipelinePtr = std::unique_ptr; /** * Constructs speech recognition pipeline based on configuration provided. * * @param[in] config - speech recognition pipeline configuration. * @param[in] labels - asr labels * * @return unique pointer to asr pipeline. */ IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map& labels); }// namespace asr