Diffstat (limited to 'samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp'):
 samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp | 59
 1 file changed, 28 insertions(+), 31 deletions(-)
diff --git a/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp b/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp
index 47ce30416f..bc3fbfe151 100644
--- a/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp
+++ b/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp
@@ -8,16 +8,16 @@
#include "ArmnnNetworkExecutor.hpp"
#include "Decoder.hpp"
#include "MFCC.hpp"
-#include "Preprocess.hpp"
+#include "Wav2LetterPreprocessor.hpp"
-namespace asr
+namespace asr
{
/**
* Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference
* result post-processing.
*
*/
-class ASRPipeline
+class ASRPipeline
{
public:
@@ -27,7 +27,7 @@ public:
* @param decoder - unique pointer to inference results decoder
*/
ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
- std::unique_ptr<Decoder> decoder);
+ std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor);
/**
* @brief Standard audio pre-processing implementation.
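Note: the constructor now also takes ownership of a Wav2LetterPreprocessor. A minimal construction sketch, assuming the executor, decoder and preprocessor have already been built elsewhere; the sample itself wires these up inside CreatePipeline, and the argument lists for the three components are placeholders, not taken from this header:

    // Hypothetical wiring; the component constructors' arguments are assumed.
    auto executor = std::make_unique<common::ArmnnNetworkExecutor<int8_t>>(/* model path, backends */);
    auto decoder = std::make_unique<asr::Decoder>(/* labels */);
    auto preprocessor = std::make_unique<asr::Wav2LetterPreprocessor>(/* MFCC settings */);

    asr::ASRPipeline pipeline(std::move(executor),
                              std::move(decoder),
                              std::move(preprocessor));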
@@ -36,20 +36,16 @@ public:
* extracting the MFCC features.
* @param[in] audio - the raw audio data
- * @param[out] preprocessor - the preprocessor object, which handles the data prepreration
+ * @param[out] preprocessor - the preprocessor object, which handles the data preparation
*/
- template<typename Tin,typename Tout>
- std::vector<Tout> PreProcessing(std::vector<Tin>& audio, Preprocess& preprocessor)
- {
- int audioDataToPreProcess = preprocessor._m_windowLen +
- ((preprocessor._m_mfcc._m_params.m_numMfccVectors -1) *preprocessor._m_windowStride);
- int outputBufferSize = preprocessor._m_mfcc._m_params.m_numMfccVectors
- * preprocessor._m_mfcc._m_params.m_numMfccFeatures * 3;
- std::vector<Tout> outputBuffer(outputBufferSize);
- preprocessor.Invoke(audio.data(), audioDataToPreProcess, outputBuffer, m_executor->GetQuantizationOffset(),
- m_executor->GetQuantizationScale());
- return outputBuffer;
- }
+ std::vector<int8_t> PreProcessing(std::vector<float>& audio);
+
+ int getInputSamplesSize();
+ int getSlidingWindowOffset();
+
+ // Exposing hardcoded constant as it can only be derived from model knowledge and not from model itself
+ // Will need to be refactored so that hard coded values are not defined outside of model settings
+ int SLIDING_WINDOW_OFFSET;
/**
* @brief Executes inference
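The removed inline implementation is still useful for understanding the buffer sizing that now lives behind PreProcessing: the number of raw samples consumed is one analysis window plus (numMfccVectors - 1) further strides, and the output holds numMfccVectors MFCC frames of numMfccFeatures coefficients each, tripled to make room for first and second derivatives. A standalone sketch of that arithmetic, using illustrative Wav2Letter-style values rather than the sample's actual MFCC parameters:

    #include <cstdio>

    int main()
    {
        // Illustrative values; the real ones come from the MFCC parameters.
        const int windowLen       = 512; // samples per analysis window
        const int windowStride    = 160; // samples between window starts
        const int numMfccVectors  = 296; // MFCC frames per inference
        const int numMfccFeatures = 13;  // coefficients per frame

        // Raw samples needed: the first window plus one stride per extra frame.
        const int audioDataToPreProcess =
            windowLen + (numMfccVectors - 1) * windowStride;

        // Output values: MFCCs plus first and second derivatives (hence * 3).
        const int outputBufferSize = numMfccVectors * numMfccFeatures * 3;

        std::printf("samples in: %d\n", audioDataToPreProcess); // 47712
        std::printf("values out: %d\n", outputBufferSize);      // 11544
        return 0;
    }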
@@ -60,9 +56,9 @@ public:
* @param[out] result - raw inference results.
*/
template<typename T>
- void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
+ void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
{
- size_t data_bytes = sizeof(std::vector<T>) + (sizeof(T) * preprocessedData.size());
+ size_t data_bytes = sizeof(T) * preprocessedData.size();
m_executor->Run(preprocessedData.data(), data_bytes, result);
}
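The data_bytes change above fixes a real bug: sizeof(std::vector<T>) is the size of the vector object itself (its pointers and bookkeeping, typically 24 bytes on a 64-bit build), not of its elements, so the old expression over-reported the payload handed to the executor. A standalone check of the two expressions:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int8_t> data(100);
        // Old, buggy: counts the vector's own bookkeeping as payload.
        std::printf("old: %zu\n",
                    sizeof(std::vector<int8_t>) + sizeof(int8_t) * data.size()); // e.g. 124
        // New, fixed: payload is element size times element count.
        std::printf("new: %zu\n", sizeof(int8_t) * data.size());                 // 100
        return 0;
    }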
@@ -78,9 +74,9 @@ public:
*/
template<typename T>
void PostProcessing(common::InferenceResults<int8_t>& inferenceResult,
- bool& isFirstWindow,
- bool isLastWindow,
- std::string currentRContext)
+ bool& isFirstWindow,
+ bool isLastWindow,
+ std::string currentRContext)
{
int rowLength = 29;
int middleContextStart = 49;
@@ -92,17 +88,17 @@ public:
std::vector<T> contextToProcess;
// If isFirstWindow we keep the left context of the output
- if(isFirstWindow)
+ if (isFirstWindow)
{
std::vector<T> chunk(&inferenceResult[0][leftContextStart],
- &inferenceResult[0][middleContextEnd * rowLength]);
+ &inferenceResult[0][middleContextEnd * rowLength]);
contextToProcess = chunk;
}
- // Else we only keep the middle context of the output
- else
+ else
{
+ // Else we only keep the middle context of the output
std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength],
- &inferenceResult[0][middleContextEnd * rowLength]);
+ &inferenceResult[0][middleContextEnd * rowLength]);
contextToProcess = chunk;
}
std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess);
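The slicing above works on a row-major output with rowLength = 29 class scores per timestep, so multiplying a timestep index by rowLength yields a flat element offset into the inference result. Only rowLength and middleContextStart are visible in this hunk; the remaining bounds (middleContextEnd, rightContextStart, rightContextEnd) are illustrative assumptions in the standalone sketch below:

    #include <cstdio>

    int main()
    {
        const int rowLength          = 29; // class scores per timestep (from the hunk)
        const int middleContextStart = 49; // from the hunk
        // The bounds below are assumed for illustration only.
        const int middleContextEnd   = 99;
        const int rightContextStart  = 99;
        const int rightContextEnd    = 148;

        // First window: keep left + middle context, timesteps [0, middleContextEnd).
        std::printf("first window: elements [0, %d)\n",
                    middleContextEnd * rowLength);              // 2871
        // Later windows: keep only the middle context.
        std::printf("later windows: elements [%d, %d)\n",
                    middleContextStart * rowLength,             // 1421
                    middleContextEnd * rowLength);              // 2871
        // Last window: additionally decode and print the right context.
        std::printf("last window tail: elements [%d, %d)\n",
                    rightContextStart * rowLength,              // 2871
                    rightContextEnd * rowLength);               // 4292
        return 0;
    }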
@@ -110,10 +106,10 @@ public:
std::cout << output << std::flush;
// If this is the last window, we print the right context of the output
- if(isLastWindow)
+ if (isLastWindow)
{
- std::vector<T> rContext(&inferenceResult[0][rightContextStart*rowLength],
- &inferenceResult[0][rightContextEnd * rowLength]);
+ std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength],
+ &inferenceResult[0][rightContextEnd * rowLength]);
currentRContext = this->m_decoder->DecodeOutput(rContext);
std::cout << currentRContext << std::endl;
}
@@ -122,6 +118,7 @@ public:
protected:
std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
std::unique_ptr<Decoder> m_decoder;
+ std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor;
};
using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;
@@ -136,4 +133,4 @@ using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;
*/
IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels);
-}// namespace asr
\ No newline at end of file
+} // namespace asr
\ No newline at end of file
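Taken together, the reworked interface supports a sliding-window loop over the audio: fetch getInputSamplesSize() samples, run the three stages, then advance by getSlidingWindowOffset(). A hedged usage sketch, assuming a pipeline built by CreatePipeline and an already-loaded float audio buffer; whether PostProcessing itself clears isFirstWindow is not visible in this diff, so the sketch passes it by reference as the signature requires and relies on that behaviour:

    #include "SpeechRecognitionPipeline.hpp"

    #include <cstdint>
    #include <string>
    #include <vector>

    void RunAsr(asr::IPipelinePtr& pipeline, std::vector<float>& audio)
    {
        const size_t windowSize = pipeline->getInputSamplesSize();
        const size_t stride     = pipeline->getSlidingWindowOffset();

        bool isFirstWindow = true;
        std::string rightContext;

        for (size_t offset = 0; offset + windowSize <= audio.size(); offset += stride)
        {
            std::vector<float> window(audio.begin() + offset,
                                      audio.begin() + offset + windowSize);

            // Stage 1: raw audio -> quantized MFCC features.
            std::vector<int8_t> input = pipeline->PreProcessing(window);

            // Stage 2: run the network on the preprocessed window.
            common::InferenceResults<int8_t> results;
            pipeline->Inference<int8_t>(input, results);

            // Stage 3: decode and print the relevant context of the output.
            const bool isLastWindow = offset + stride + windowSize > audio.size();
            pipeline->PostProcessing<int8_t>(results, isFirstWindow,
                                             isLastWindow, rightContext);
        }
    }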