From c291144b7f08c21d08cdaf79cc64dc420ca70070 Mon Sep 17 00:00:00 2001
From: Richard Burton
Date: Fri, 22 Apr 2022 09:08:21 +0100
Subject: MLECO-3077: Add ASR use case API

* Minor adjustments to doc strings in KWS
* Remove unused score threshold in KWS

Signed-off-by: Richard Burton
Change-Id: Ie1c5bf6f7bdbebb853b6a10cb7ba1c4a1d9a76c9
---
 source/use_case/asr/include/AsrResult.hpp          |   2 +-
 source/use_case/asr/include/Wav2LetterModel.hpp    |   3 +
 .../use_case/asr/include/Wav2LetterPostprocess.hpp | 115 +++++++-------
 .../use_case/asr/include/Wav2LetterPreprocess.hpp  |  96 +++++-------
 source/use_case/asr/src/MainLoop.cc                |  85 +----------
 source/use_case/asr/src/UseCaseHandler.cc          | 166 ++++++++++-----------
 source/use_case/asr/src/Wav2LetterMfcc.cc          |   1 +
 source/use_case/asr/src/Wav2LetterModel.cc         |   1 +
 source/use_case/asr/src/Wav2LetterPostprocess.cc   | 153 ++++++++++++-------
 source/use_case/asr/src/Wav2LetterPreprocess.cc    | 106 ++++++-------
 source/use_case/kws/include/KwsProcessing.hpp      |  23 ++-
 source/use_case/kws/src/KwsProcessing.cc           |   5 +-
 source/use_case/kws/src/UseCaseHandler.cc          |   4 +-
 tests/common/PlatformMathTests.cpp                 |  33 +++-
 tests/use_case/asr/AsrFeaturesTests.cc             |  52 +------
 tests/use_case/asr/Wav2LetterPostprocessingTest.cc | 124 ++++++++-------
 tests/use_case/asr/Wav2LetterPreprocessingTest.cc  | 120 +++++++--------
 17 files changed, 495 insertions(+), 594 deletions(-)

diff --git a/source/use_case/asr/include/AsrResult.hpp b/source/use_case/asr/include/AsrResult.hpp
index b12ed7d..ed826d0 100644
--- a/source/use_case/asr/include/AsrResult.hpp
+++ b/source/use_case/asr/include/AsrResult.hpp
@@ -25,7 +25,7 @@ namespace arm {
 namespace app {
 namespace asr {
 
-    using ResultVec = std::vector < arm::app::ClassificationResult >;
+    using ResultVec = std::vector<arm::app::ClassificationResult>;
 
     /* Structure for holding ASR result. */
     class AsrResult {

diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp
index 55395b9..895df2b 100644
--- a/source/use_case/asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/asr/include/Wav2LetterModel.hpp
@@ -36,6 +36,9 @@ namespace app {
         static constexpr uint32_t ms_outputRowsIdx = 2;
         static constexpr uint32_t ms_outputColsIdx = 3;
 
+        static constexpr uint32_t ms_blankTokenIdx   = 28;
+        static constexpr uint32_t ms_numMfccFeatures = 13;
+
     protected:
         /** @brief Gets the reference to op resolver interface class. */
         const tflite::MicroOpResolver& GetOpResolver() override;
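Note on the new constants above: ms_blankTokenIdx moves the CTC blank-token index, previously hard-coded in MainLoop.cc as blankTokenIdx = 28 (removed later in this patch), into the model class. To show the role such an index plays, a greedy CTC decode collapses repeated symbols and drops blanks. The sketch below is illustrative only - it is not code from this patch, nor the audio::asr::DecodeOutput() helper the use case actually calls, and stepIdx/labels are hypothetical inputs:

    #include <cstdint>
    #include <string>
    #include <vector>

    /* Illustrative greedy CTC decode: stepIdx holds the top-1 label index
     * for each output time step; blankIdx is the CTC blank token index. */
    static std::string GreedyCtcDecode(const std::vector<uint32_t>& stepIdx,
                                       const std::vector<std::string>& labels,
                                       const uint32_t blankIdx)
    {
        std::string decoded;
        uint32_t prev = blankIdx;
        for (const uint32_t idx : stepIdx) {
            /* Emit a symbol only when it changes and is not the blank. */
            if (idx != prev && idx != blankIdx) {
                decoded += labels[idx];
            }
            prev = idx;
        }
        return decoded;
    }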
diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
index 29eb548..45defa5 100644
--- a/source/use_case/asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,93 +17,90 @@
 #ifndef ASR_WAV2LETTER_POSTPROCESS_HPP
 #define ASR_WAV2LETTER_POSTPROCESS_HPP
 
-#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */
+#include "TensorFlowLiteMicro.hpp"  /* TensorFlow headers. */
+#include "BaseProcessing.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
 #include "log_macros.h"
 
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
     /**
      * @brief Helper class to manage tensor post-processing for "wav2letter"
      *        output.
      */
-    class Postprocess {
+    class ASRPostProcess : public BasePostProcess {
     public:
+        bool m_lastIteration = false;  /* Flag to set if processing the last set of data for a clip. */
+
         /**
-         * @brief Constructor
-         * @param[in] contextLen       Left and right context length for
-         *                             output tensor.
-         * @param[in] innerLen         This is the length of the section
-         *                             between left and right context.
-         * @param[in] blankTokenIdx    Blank token index.
+         * @brief Constructor
+         * @param[in]     classifier        Object used to get top N results from classification.
+         * @param[in]     outputTensor      Pointer to the output Tensor.
+         * @param[in]     labels            Vector of string labels to identify each output of the model.
+         * @param[in,out] result            Vector of classification results to store decoded outputs.
+         * @param[in]     outputContextLen  Left/right context length for output tensor.
+         * @param[in]     blankTokenIdx     Index in the labels that the "Blank token" takes.
+         * @param[in]     reductionAxis     The axis that the logits of each time step are on.
         **/
-        Postprocess(uint32_t contextLen,
-                    uint32_t innerLen,
-                    uint32_t blankTokenIdx);
-
-        Postprocess() = delete;
-        ~Postprocess() = default;
+        ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+                       const std::vector<std::string>& labels, asr::ResultVec& result,
+                       uint32_t outputContextLen,
+                       uint32_t blankTokenIdx, uint32_t reductionAxis);
 
         /**
-         * @brief Erases the required part of the tensor based
-         *        on context lengths set up during initialisation.
-         * @param[in] tensor         Pointer to the tensor.
-         * @param[in] axisIdx        Index of the axis on which erase is
-         *                           performed.
-         * @param[in] lastIteration  Flag to signal this is the
-         *                           last iteration in which case
-         *                           the right context is preserved.
-         * @return true if successful, false otherwise.
-         */
-        bool Invoke(TfLiteTensor* tensor,
-                    uint32_t axisIdx,
-                    bool lastIteration = false);
+         * @brief  Should perform post-processing of the result of inference then
+         *         populate ASR result data for any later use.
+         * @return true if successful, false otherwise.
+         **/
+        bool DoPostProcess() override;
+
+        /** @brief Gets the output inner length for post-processing. */
+        static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen);
+
+        /** @brief Gets the output context length (left/right) for post-processing. */
+        static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen);
+
+        /** @brief Gets the number of feature vectors to be computed. */
+        static uint32_t GetNumFeatureVectors(const Model& model);
 
     private:
-        uint32_t m_contextLen;       /* Lengths of left and right contexts. */
-        uint32_t m_innerLen;         /* Length of inner context. */
-        uint32_t m_totalLen;         /* Total length of the required axis. */
-        uint32_t m_countIterations;  /* Current number of iterations. */
-        uint32_t m_blankTokenIdx;    /* Index of the labels blank token. */
+        AsrClassifier& m_classifier;               /* ASR Classifier object. */
+        TfLiteTensor* m_outputTensor;              /* Model output tensor. */
+        const std::vector<std::string>& m_labels;  /* ASR labels. */
+        asr::ResultVec& m_results;                 /* Results vector for a single inference. */
+        uint32_t m_outputContextLen;               /* Lengths of left/right contexts for output. */
+        uint32_t m_outputInnerLen;                 /* Length of output inner context. */
+        uint32_t m_totalLen;                       /* Total length of the required axis. */
+        uint32_t m_countIterations;                /* Current number of iterations. */
+        uint32_t m_blankTokenIdx;                  /* Index of the labels blank token. */
+        uint32_t m_reductionAxisIdx;               /* Axis containing output logits for a single step. */
 
         /**
-         * @brief Checks if the tensor and axis index are valid
-         *        inputs to the object - based on how it has been
-         *        initialised.
-         * @return true if valid, false otherwise.
+         * @brief  Checks if the tensor and axis index are valid
+         *         inputs to the object - based on how it has been initialised.
+         * @return true if valid, false otherwise.
          */
         bool IsInputValid(TfLiteTensor* tensor,
-                          const uint32_t axisIdx) const;
+                          uint32_t axisIdx) const;
 
         /**
-         * @brief Gets the tensor data element size in bytes based
-         *        on the tensor type.
-         * @return Size in bytes, 0 if not supported.
+         * @brief  Gets the tensor data element size in bytes based
+         *         on the tensor type.
+         * @return Size in bytes, 0 if not supported.
          */
         static uint32_t GetTensorElementSize(TfLiteTensor* tensor);
 
         /**
-         * @brief Erases sections from the data assuming row-wise
-         *        arrangement along the context axis.
-         * @return true if successful, false otherwise.
+         * @brief  Erases sections from the data assuming row-wise
+         *         arrangement along the context axis.
+         * @return true if successful, false otherwise.
          */
         bool EraseSectionsRowWise(uint8_t* ptrData,
-                                  const uint32_t strideSzBytes,
-                                  const bool lastIteration);
-
-        /**
-         * @brief Erases sections from the data assuming col-wise
-         *        arrangement along the context axis.
-         * @return true if successful, false otherwise.
-         */
-        static bool EraseSectionsColWise(const uint8_t* ptrData,
-                                  const uint32_t strideSzBytes,
-                                  const bool lastIteration);
+                                  uint32_t strideSzBytes,
+                                  bool lastIteration);
     };
 
-} /* namespace asr */
-} /* namespace audio */
 } /* namespace app */
 } /* namespace arm */
diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
index 13d1589..8c12b3d 100644
--- a/source/use_case/asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,49 +21,44 @@
 #include "Wav2LetterMfcc.hpp"
 #include "AudioUtils.hpp"
 #include "DataStructures.hpp"
+#include "BaseProcessing.hpp"
 #include "log_macros.h"
 
 namespace arm {
 namespace app {
-namespace audio {
-namespace asr {
 
     /* Class to facilitate pre-processing calculation for Wav2Letter model
     * for ASR. */
-    using AudioWindow = SlidingWindow <const int16_t>;
+    using AudioWindow = audio::SlidingWindow<const int16_t>;
 
-    class Preprocess {
+    class ASRPreProcess : public BasePreProcess {
    public:
        /**
        * @brief Constructor.
-        * @param[in] numMfccFeatures   Number of MFCC features per window.
-        * @param[in] windowLen         Number of elements in a window.
-        * @param[in] windowStride      Stride (in number of elements) for
-        *                              moving the window.
-        * @param[in] numMfccVectors    Number of MFCC vectors per window.
+        * @param[in] inputTensor       Pointer to the TFLite Micro input Tensor.
+        * @param[in] numMfccFeatures   Number of MFCC features per window.
+        * @param[in] numFeatureFrames  Number of MFCC vectors that need to be calculated
+        *                              for an inference.
+        * @param[in] mfccWindowLen     Number of audio elements to calculate MFCC features per window.
+        * @param[in] mfccWindowStride  Stride (in number of elements) for moving the MFCC window.
        */
-        Preprocess(
-            uint32_t numMfccFeatures,
-            uint32_t windowLen,
-            uint32_t windowStride,
-            uint32_t numMfccVectors);
-        Preprocess() = delete;
-        ~Preprocess() = default;
+        ASRPreProcess(TfLiteTensor* inputTensor,
+                      uint32_t numMfccFeatures,
+                      uint32_t numFeatureFrames,
+                      uint32_t mfccWindowLen,
+                      uint32_t mfccWindowStride);
 
        /**
        * @brief       Calculates the features required from audio data. This
        *              includes MFCC, first and second order deltas,
        *              normalisation and finally, quantisation. The tensor is
-       *              populated with feature from a given window placed along
+       *              populated with features from a given window placed along
       *              in a single row.
       * @param[in]   audioData      Pointer to the first element of audio data.
       * @param[in]   audioDataLen   Number of elements in the audio data.
-       * @param[in]   tensor         Tensor to be populated.
       * @return      true if successful, false in case of error.
       */
-        bool Invoke(const int16_t * audioData,
-                    uint32_t audioDataLen,
-                    TfLiteTensor * tensor);
+        bool DoPreProcess(const void* audioData, size_t audioDataLen) override;
 
    protected:
@@ -80,32 +75,16 @@
                                  Array2d<float>& delta1,
                                  Array2d<float>& delta2);
 
        /**
-        * @brief Given a 2D vector of floats, computes the mean.
-        * @param[in] vec   Vctor of vector of floats.
-        * @return          Mean value.
-        */
-        static float GetMean(Array2d<float>& vec);
-
-       /**
-        * @brief Given a 2D vector of floats, computes the stddev.
-        * @param[in] vec    Vector of vector of floats.
-        * @param[in] mean   Mean value of the vector passed in.
-        * @return           stddev value.
-        */
-        static float GetStdDev(Array2d<float>& vec,
-                               const float mean);
-
-       /**
-        * @brief Given a 2D vector of floats, normalises it using
-        *        the mean and the stddev.
+        * @brief Given a 2D vector of floats, rescale it to have mean of 0 and
+        *        standard deviation of 1.
        * @param[in,out] vec   Vector of vector of floats.
        */
-        static void NormaliseVec(Array2d<float>& vec);
+        static void StandardizeVecF32(Array2d<float>& vec);
 
        /**
-        * @brief Normalises the MFCC and delta buffers.
+        * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1.
        */
-        void Normalise();
+        void Standarize();
 
        /**
        * @brief Given the quantisation and data type limits, computes
@@ -139,7 +118,7 @@
        */
        template <typename T>
        bool Quantise(
-            T * outputBuf,
+            T* outputBuf,
            const uint32_t outputBufSz,
            const float quantScale,
            const int quantOffset)
@@ -160,15 +139,15 @@
            const float minVal = std::numeric_limits<T>::min();
            const float maxVal = std::numeric_limits<T>::max();
 
            /* Need to transpose while copying and concatenating the tensor. */
-            for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+            for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
                for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
-                    *outputBufMfcc++ = static_cast<T>(Preprocess::GetQuantElem(
+                    *outputBufMfcc++ = static_cast<T>(ASRPreProcess::GetQuantElem(
                            this->m_mfccBuf(i, j), quantScale,
                            quantOffset, minVal, maxVal));
-                    *outputBufD1++ = static_cast<T>(Preprocess::GetQuantElem(
+                    *outputBufD1++ = static_cast<T>(ASRPreProcess::GetQuantElem(
                            this->m_delta1Buf(i, j), quantScale,
                            quantOffset, minVal, maxVal));
-                    *outputBufD2++ = static_cast<T>(Preprocess::GetQuantElem(
+                    *outputBufD2++ = static_cast<T>(ASRPreProcess::GetQuantElem(
                            this->m_delta2Buf(i, j), quantScale,
                            quantOffset, minVal, maxVal));
                }
@@ -181,23 +160,22 @@
            return true;
        }
 
    private:
-        Wav2LetterMFCC m_mfcc;         /* MFCC instance. */
+        audio::Wav2LetterMFCC m_mfcc;  /* MFCC instance. */
+        TfLiteTensor* m_inputTensor;   /* Model input tensor. */
 
        /* Actual buffers to be populated.
*/ - Array2d m_mfccBuf; /* Contiguous buffer 1D: MFCC */ - Array2d m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ - Array2d m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ + Array2d m_mfccBuf; /* Contiguous buffer 1D: MFCC */ + Array2d m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ + Array2d m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ - uint32_t m_windowLen; /* Window length for MFCC. */ - uint32_t m_windowStride; /* Window stride len for MFCC. */ - uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ - uint32_t m_numFeatVectors; /* Number of m_numMfccFeats. */ - AudioWindow m_window; /* Sliding window. */ + uint32_t m_mfccWindowLen; /* Window length for MFCC. */ + uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */ + uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ + uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */ + AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */ }; -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ diff --git a/source/use_case/asr/src/MainLoop.cc b/source/use_case/asr/src/MainLoop.cc index 51b0b18..a1a9540 100644 --- a/source/use_case/asr/src/MainLoop.cc +++ b/source/use_case/asr/src/MainLoop.cc @@ -14,15 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions. */ #include "Labels.hpp" /* For label strings. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ #include "Wav2LetterModel.hpp" /* Model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "AsrClassifier.hpp" /* Classifier. */ #include "InputFiles.hpp" /* Generated audio clip header. */ -#include "Wav2LetterPreprocess.hpp" /* Pre-processing class. */ -#include "Wav2LetterPostprocess.hpp" /* Post-processing class. */ #include "log_macros.h" enum opcodes @@ -48,23 +45,9 @@ static void DisplayMenu() fflush(stdout); } -/** @brief Verify input and output tensor are of certain min dimensions. */ +/** @brief Verify input and output tensor are of certain min dimensions. */ static bool VerifyTensorDimensions(const arm::app::Model& model); -/** @brief Gets the number of MFCC features for a single window. */ -static uint32_t GetNumMfccFeatures(const arm::app::Model& model); - -/** @brief Gets the number of MFCC feature vectors to be computed. */ -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); - -/** @brief Gets the output context length (left and right) for post-processing. */ -static uint32_t GetOutputContextLen(const arm::app::Model& model, - uint32_t inputCtxLen); - -/** @brief Gets the output inner length for post-processing. */ -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - uint32_t outputCtxLen); - void main_loop() { arm::app::Wav2LetterModel model; /* Model wrapper object. */ @@ -78,21 +61,6 @@ void main_loop() return; } - /* Initialise pre-processing. */ - arm::app::audio::asr::Preprocess prep( - GetNumMfccFeatures(model), - g_FrameLength, - g_FrameStride, - GetNumMfccFeatureVectors(model)); - - /* Initialise post-processing. */ - const uint32_t outputCtxLen = GetOutputContextLen(model, g_ctxLen); - const uint32_t blankTokenIdx = 28; - arm::app::audio::asr::Postprocess postp( - outputCtxLen, - GetOutputInnerLen(model, outputCtxLen), - blankTokenIdx); - /* Instantiate application context. 
*/ arm::app::ApplicationContext caseContext; std::vector labels; @@ -109,8 +77,6 @@ void main_loop() caseContext.Set("ctxLen", g_ctxLen); /* Left and right context length (MFCC feat vectors). */ caseContext.Set&>("labels", labels); caseContext.Set("classifier", classifier); - caseContext.Set("preprocess", prep); - caseContext.Set("postprocess", postp); bool executionSuccessful = true; constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; @@ -184,52 +150,3 @@ static bool VerifyTensorDimensions(const arm::app::Model& model) return true; } - -static uint32_t GetNumMfccFeatures(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; - if (0 != inputCols % 3) { - printf_err("Number of input columns is not a multiple of 3\n"); - } - return std::max(inputCols/3, 0); -} - -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - return std::max(inputRows, 0); -} - -static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) -{ - const uint32_t inputRows = GetNumMfccFeatureVectors(model); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - - /* Check to make sure that the input tensor supports the above - * context and inner lengths. */ - if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { - printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", - inputCtxLen); - return 0; - } - - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - - const float tensorColRatio = static_cast(inputRows)/ - static_cast(outputRows); - - return std::round(static_cast(inputCtxLen)/tensorColRatio); -} - -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - const uint32_t outputCtxLen) -{ - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - return (outputRows - (2 * outputCtxLen)); -} diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc index 420f725..7fe959b 100644 --- a/source/use_case/asr/src/UseCaseHandler.cc +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -20,7 +20,6 @@ #include "AsrClassifier.hpp" #include "Wav2LetterModel.hpp" #include "hal.h" -#include "Wav2LetterMfcc.hpp" #include "AudioUtils.hpp" #include "ImageUtils.hpp" #include "UseCaseCommonUtils.hpp" @@ -34,68 +33,63 @@ namespace arm { namespace app { /** - * @brief Presents inference results using the data presentation - * object. - * @param[in] results Vector of classification results to be displayed. + * @brief Presents ASR inference results. + * @param[in] results Vector of ASR classification results to be displayed. * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(const std::vector& results); + static bool PresentInferenceResult(const std::vector& results); - /* Audio inference classification handler. */ + /* ASR inference handler. 
*/ bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - constexpr uint32_t dataPsnTxtInfStartX = 20; - constexpr uint32_t dataPsnTxtInfStartY = 40; - - hal_lcd_clear(COLOR_BLACK); - + auto& model = ctx.Get("model"); auto& profiler = ctx.Get("profiler"); - + auto mfccFrameLen = ctx.Get("frameLength"); + auto mfccFrameStride = ctx.Get("frameStride"); + auto scoreThreshold = ctx.Get("scoreThreshold"); + auto inputCtxLen = ctx.Get("ctxLen"); /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { return false; } } + auto initialClipIdx = ctx.Get("clipIndex"); + constexpr uint32_t dataPsnTxtInfStartX = 20; + constexpr uint32_t dataPsnTxtInfStartY = 40; - /* Get model reference. */ - auto& model = ctx.Get("model"); if (!model.IsInited()) { printf_err("Model is not initialised! Terminating processing.\n"); return false; } - /* Get score threshold to be applied for the classifier (post-inference). */ - auto scoreThreshold = ctx.Get("scoreThreshold"); - - /* Get tensors. Dimensions of the tensor should have been verified by + /* Get input shape. Dimensions of the tensor should have been verified by * the callee. */ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; + TfLiteIntArray* inputShape = model.GetInputShape(0); - /* Populate MFCC related parameters. */ - auto mfccParamsWinLen = ctx.Get("frameLength"); - auto mfccParamsWinStride = ctx.Get("frameStride"); - - /* Populate ASR inference context and inner lengths for input. */ - auto inputCtxLen = ctx.Get("ctxLen"); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + const uint32_t inputRowsSize = inputShape->data[Wav2LetterModel::ms_inputRowsIdx]; + const uint32_t inputInnerLen = inputRowsSize - (2 * inputCtxLen); /* Audio data stride corresponds to inputInnerLen feature vectors. */ - const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen); - const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride; - const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq); + const uint32_t audioDataWindowLen = (inputRowsSize - 1) * mfccFrameStride + (mfccFrameLen); + const uint32_t audioDataWindowStride = inputInnerLen * mfccFrameStride; + + /* NOTE: This is only used for time stamp calculation. */ + const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq); - /* Get pre/post-processing objects. */ - auto& prep = ctx.Get("preprocess"); - auto& postp = ctx.Get("postprocess"); + /* Set up pre and post-processing objects. */ + ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride); - /* Set default reduction axis for post-processing. */ - const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx; + std::vector singleInfResult; + const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen); + ASRPostProcess postProcess = ASRPostProcess(ctx.Get("classifier"), + model.GetOutputTensor(0), ctx.Get&>("labels"), + singleInfResult, outputCtxLen, + Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx + ); - /* Audio clip start index. 
*/ - auto startClipIdx = ctx.Get("clipIndex"); + UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model); /* Loop to process audio clips. */ do { @@ -109,44 +103,41 @@ namespace app { const uint32_t audioArrSize = get_audio_array_size(currentIndex); if (!audioArr) { - printf_err("Invalid audio array pointer\n"); + printf_err("Invalid audio array pointer.\n"); return false; } - /* Audio clip must have enough samples to produce 1 MFCC feature. */ - if (audioArrSize < mfccParamsWinLen) { + /* Audio clip needs enough samples to produce at least 1 MFCC feature. */ + if (audioArrSize < mfccFrameLen) { printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n", - mfccParamsWinLen); + mfccFrameLen); return false; } - /* Initialise an audio slider. */ + /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::FractionalSlidingWindow( - audioArr, - audioArrSize, - audioParamsWinLen, - audioParamsWinStride); + audioArr, audioArrSize, + audioDataWindowLen, audioDataWindowStride); - /* Declare a container for results. */ - std::vector results; + /* Declare a container for final results. */ + std::vector finalResults; /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex)); - size_t inferenceWindowLen = audioParamsWinLen; + size_t inferenceWindowLen = audioDataWindowLen; /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { - /* If not enough audio see how much can be sent for processing. */ + /* If not enough audio, see how much can be sent for processing. */ size_t nextStartIndex = audioDataSlider.NextWindowStartIndex(); - if (nextStartIndex + audioParamsWinLen > audioArrSize) { + if (nextStartIndex + audioDataWindowLen > audioArrSize) { inferenceWindowLen = audioArrSize - nextStartIndex; } @@ -155,46 +146,40 @@ namespace app { info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, static_cast(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); - /* Calculate MFCCs, deltas and populate the input tensor. */ - prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor); + /* Run the pre-processing, inference and post-processing. */ + runner.PreProcess(inferenceWindow, inferenceWindowLen); - /* Run inference over this audio clip sliding window. */ - if (!RunInference(model, profiler)) { + profiler.StartProfiling("Inference"); + if (!runner.RunInference()) { return false; } + profiler.StopProfiling(); - /* Post-process. */ - postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext()); - - /* Get results. */ - std::vector classificationResult; - auto& classifier = ctx.Get("classifier"); - classifier.GetClassificationResults( - outputTensor, classificationResult, - ctx.Get&>("labels"), 1); + postProcess.m_lastIteration = !audioDataSlider.HasNext(); + if (!runner.PostProcess()) { + return false; + } - results.emplace_back(asr::AsrResult(classificationResult, - (audioDataSlider.Index() * - audioParamsSecondsPerSample * - audioParamsWinStride), - audioDataSlider.Index(), scoreThreshold)); + /* Add results from this window to our final results vector. 
*/ + finalResults.emplace_back(asr::AsrResult(singleInfResult, + (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride), + audioDataSlider.Index(), scoreThreshold)); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(outputTensor, - outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + armDumpTensor(outputTensor, + outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); #endif /* VERIFY_TEST_OUTPUT */ - - } + } /* while (audioDataSlider.HasNext()) */ /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - ctx.Set>("results", results); + ctx.Set>("results", finalResults); - if (!PresentInferenceResult(results)) { + if (!PresentInferenceResult(finalResults)) { return false; } @@ -202,13 +187,13 @@ namespace app { IncrementAppCtxIfmIdx(ctx,"clipIndex"); - } while (runAll && ctx.Get("clipIndex") != startClipIdx); + } while (runAll && ctx.Get("clipIndex") != initialClipIdx); return true; } - static bool PresentInferenceResult(const std::vector& results) + static bool PresentInferenceResult(const std::vector& results) { constexpr uint32_t dataPsnTxtStartX1 = 20; constexpr uint32_t dataPsnTxtStartY1 = 60; @@ -219,15 +204,15 @@ namespace app { info("Final results:\n"); info("Total number of inferences: %zu\n", results.size()); /* Results from multiple inferences should be combined before processing. */ - std::vector combinedResults; - for (auto& result : results) { + std::vector combinedResults; + for (const auto& result : results) { combinedResults.insert(combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end()); } /* Get each inference result string using the decoder. */ - for (const auto & result : results) { + for (const auto& result : results) { std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n", @@ -238,10 +223,9 @@ namespace app { /* Get the decoded result for the combined result. */ std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); - hal_lcd_display_text( - finalResultStr.c_str(), finalResultStr.size(), - dataPsnTxtStartX1, dataPsnTxtStartY1, - allow_multiple_lines); + hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(), + dataPsnTxtStartX1, dataPsnTxtStartY1, + allow_multiple_lines); info("Complete recognition: %s\n", finalResultStr.c_str()); return true; diff --git a/source/use_case/asr/src/Wav2LetterMfcc.cc b/source/use_case/asr/src/Wav2LetterMfcc.cc index 1bcaa66..bb29b0f 100644 --- a/source/use_case/asr/src/Wav2LetterMfcc.cc +++ b/source/use_case/asr/src/Wav2LetterMfcc.cc @@ -15,6 +15,7 @@ * limitations under the License. */ #include "Wav2LetterMfcc.hpp" + #include "PlatformMath.hpp" #include "log_macros.h" diff --git a/source/use_case/asr/src/Wav2LetterModel.cc b/source/use_case/asr/src/Wav2LetterModel.cc index 766bce9..8b38f4f 100644 --- a/source/use_case/asr/src/Wav2LetterModel.cc +++ b/source/use_case/asr/src/Wav2LetterModel.cc @@ -15,6 +15,7 @@ * limitations under the License. 
*/ #include "Wav2LetterModel.hpp" + #include "log_macros.h" diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc index 0392061..e3e1999 100644 --- a/source/use_case/asr/src/Wav2LetterPostprocess.cc +++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,67 +15,71 @@ * limitations under the License. */ #include "Wav2LetterPostprocess.hpp" + #include "Wav2LetterModel.hpp" #include "log_macros.h" +#include + namespace arm { namespace app { -namespace audio { -namespace asr { - - Postprocess::Postprocess(const uint32_t contextLen, - const uint32_t innerLen, - const uint32_t blankTokenIdx) - : m_contextLen(contextLen), - m_innerLen(innerLen), - m_totalLen(2 * this->m_contextLen + this->m_innerLen), + + ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor, + const std::vector& labels, std::vector& results, + const uint32_t outputContextLen, + const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx + ): + m_classifier(classifier), + m_outputTensor(outputTensor), + m_labels{labels}, + m_results(results), + m_outputContextLen(outputContextLen), m_countIterations(0), - m_blankTokenIdx(blankTokenIdx) - {} + m_blankTokenIdx(blankTokenIdx), + m_reductionAxisIdx(reductionAxisIdx) + { + this->m_outputInnerLen = ASRPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); + } - bool Postprocess::Invoke(TfLiteTensor* tensor, - const uint32_t axisIdx, - const bool lastIteration) + bool ASRPostProcess::DoPostProcess() { /* Basic checks. */ - if (!this->IsInputValid(tensor, axisIdx)) { + if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { return false; } /* Irrespective of tensor type, we use unsigned "byte" */ - uint8_t* ptrData = tflite::GetTensorData(tensor); - const uint32_t elemSz = this->GetTensorElementSize(tensor); + auto* ptrData = tflite::GetTensorData(this->m_outputTensor); + const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor); /* Other sanity checks. */ if (0 == elemSz) { printf_err("Tensor type not supported for post processing\n"); return false; - } else if (elemSz * this->m_totalLen > tensor->bytes) { + } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) { printf_err("Insufficient number of tensor bytes\n"); return false; } /* Which axis do we need to process? 
*/ - switch (axisIdx) { - case arm::app::Wav2LetterModel::ms_outputRowsIdx: - return this->EraseSectionsRowWise(ptrData, - elemSz * - tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx], - lastIteration); - case arm::app::Wav2LetterModel::ms_outputColsIdx: - return this->EraseSectionsColWise(ptrData, - elemSz * - tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx], - lastIteration); + switch (this->m_reductionAxisIdx) { + case Wav2LetterModel::ms_outputRowsIdx: + this->EraseSectionsRowWise( + ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx], + this->m_lastIteration); + break; default: - printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx); + printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx); + return false; } + this->m_classifier.GetClassificationResults(this->m_outputTensor, + this->m_results, this->m_labels, 1); - return false; + return true; } - bool Postprocess::IsInputValid(TfLiteTensor* tensor, - const uint32_t axisIdx) const + bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const { if (nullptr == tensor) { return false; @@ -89,15 +93,15 @@ namespace asr { if (static_cast(this->m_totalLen) != tensor->dims->data[axisIdx]) { - printf_err("Unexpected tensor dimension for axis %d, \n", - tensor->dims->data[axisIdx]); + printf_err("Unexpected tensor dimension for axis %d, got %d, \n", + axisIdx, tensor->dims->data[axisIdx]); return false; } return true; } - uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor) + uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor) { switch(tensor->type) { case kTfLiteUInt8: @@ -116,30 +120,30 @@ namespace asr { return 0; } - bool Postprocess::EraseSectionsRowWise( - uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration) + bool ASRPostProcess::EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) { /* In this case, the "zero-ing" is quite simple as the region * to be zeroed sits in contiguous memory (row-major). */ - const uint32_t eraseLen = strideSzBytes * this->m_contextLen; + const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen; /* Erase left context? */ if (this->m_countIterations > 0) { /* Set output of each classification window to the blank token. */ std::memset(ptrData, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } /* Erase right context? */ if (false == lastIteration) { - uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen)); + uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen)); /* Set output of each classification window to the blank token. 
*/ std::memset(rightCtxPtr, 0, eraseLen); - for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) { + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; } } @@ -153,19 +157,56 @@ namespace asr { return true; } - bool Postprocess::EraseSectionsColWise( - const uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration) + uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model) + { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); + if (inputRows == 0) { + printf_err("Error getting number of input rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_inputRowsIdx); + } + return inputRows; + } + + uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) { - /* Not implemented. */ - UNUSED(ptrData); - UNUSED(strideSzBytes); - UNUSED(lastIteration); - return false; + const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + } + int innerLen = (outputRows - (2 * outputCtxLen)); + + return std::max(innerLen, 0); + } + + uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + { + const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above + * context and inner lengths. */ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + return 0; + } + + const float inOutRowRatio = static_cast(inputRows) / + static_cast(outputRows); + + return std::round(static_cast(inputCtxLen) / inOutRowRatio); } -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc index e5ac3ca..590d08a 100644 --- a/source/use_case/asr/src/Wav2LetterPreprocess.cc +++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,37 +24,31 @@ namespace arm { namespace app { -namespace audio { -namespace asr { - - Preprocess::Preprocess( - const uint32_t numMfccFeatures, - const uint32_t windowLen, - const uint32_t windowStride, - const uint32_t numMfccVectors): - m_mfcc(numMfccFeatures, windowLen), - m_mfccBuf(numMfccFeatures, numMfccVectors), - m_delta1Buf(numMfccFeatures, numMfccVectors), - m_delta2Buf(numMfccFeatures, numMfccVectors), - m_windowLen(windowLen), - m_windowStride(windowStride), + + ASRPreProcess::ASRPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, + const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, + const uint32_t mfccWindowStride + ): + m_mfcc(numMfccFeatures, mfccWindowLen), + m_inputTensor(inputTensor), + m_mfccBuf(numMfccFeatures, numFeatureFrames), + m_delta1Buf(numMfccFeatures, numFeatureFrames), + m_delta2Buf(numMfccFeatures, numFeatureFrames), + m_mfccWindowLen(mfccWindowLen), + m_mfccWindowStride(mfccWindowStride), m_numMfccFeats(numMfccFeatures), - m_numFeatVectors(numMfccVectors), - m_window() + m_numFeatureFrames(numFeatureFrames) { - if (numMfccFeatures > 0 && windowLen > 0) { + if (numMfccFeatures > 0 && mfccWindowLen > 0) { this->m_mfcc.Init(); } } - bool Preprocess::Invoke( - const int16_t* audioData, - const uint32_t audioDataLen, - TfLiteTensor* tensor) + bool ASRPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen) { - this->m_window = SlidingWindow( - audioData, audioDataLen, - this->m_windowLen, this->m_windowStride); + this->m_mfccSlidingWindow = audio::SlidingWindow( + static_cast(audioData), audioDataLen, + this->m_mfccWindowLen, this->m_mfccWindowStride); uint32_t mfccBufIdx = 0; @@ -62,12 +56,12 @@ namespace asr { std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f); std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f); - /* While we can slide over the window. */ - while (this->m_window.HasNext()) { - const int16_t* mfccWindow = this->m_window.Next(); + /* While we can slide over the audio. */ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); auto mfccAudioData = std::vector( mfccWindow, - mfccWindow + this->m_windowLen); + mfccWindow + this->m_mfccWindowLen); auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData); for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) { this->m_mfccBuf(i, mfccBufIdx) = mfcc[i]; @@ -76,11 +70,11 @@ namespace asr { } /* Pad MFCC if needed by adding MFCC for zeros. */ - if (mfccBufIdx != this->m_numFeatVectors) { - std::vector zerosWindow = std::vector(this->m_windowLen, 0); + if (mfccBufIdx != this->m_numFeatureFrames) { + std::vector zerosWindow = std::vector(this->m_mfccWindowLen, 0); std::vector mfccZeros = this->m_mfcc.MfccCompute(zerosWindow); - while (mfccBufIdx != this->m_numFeatVectors) { + while (mfccBufIdx != this->m_numFeatureFrames) { memcpy(&this->m_mfccBuf(0, mfccBufIdx), mfccZeros.data(), sizeof(float) * m_numMfccFeats); ++mfccBufIdx; @@ -88,39 +82,37 @@ namespace asr { } /* Compute first and second order deltas from MFCCs. */ - Preprocess::ComputeDeltas(this->m_mfccBuf, - this->m_delta1Buf, - this->m_delta2Buf); + ASRPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); - /* Normalise. */ - this->Normalise(); + /* Standardize calculated features. */ + this->Standarize(); /* Quantise. 
*/ - QuantParams quantParams = GetTensorQuantParams(tensor); + QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor); if (0 == quantParams.scale) { printf_err("Quantisation scale can't be 0\n"); return false; } - switch(tensor->type) { + switch(this->m_inputTensor->type) { case kTfLiteUInt8: return this->Quantise( - tflite::GetTensorData(tensor), tensor->bytes, + tflite::GetTensorData(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); case kTfLiteInt8: return this->Quantise( - tflite::GetTensorData(tensor), tensor->bytes, + tflite::GetTensorData(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); default: printf_err("Unsupported tensor type %s\n", - TfLiteTypeGetName(tensor->type)); + TfLiteTypeGetName(this->m_inputTensor->type)); } return false; } - bool Preprocess::ComputeDeltas(Array2d& mfcc, + bool ASRPreProcess::ComputeDeltas(Array2d& mfcc, Array2d& delta1, Array2d& delta2) { @@ -175,20 +167,10 @@ namespace asr { return true; } - float Preprocess::GetMean(Array2d& vec) - { - return math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); - } - - float Preprocess::GetStdDev(Array2d& vec, const float mean) - { - return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); - } - - void Preprocess::NormaliseVec(Array2d& vec) + void ASRPreProcess::StandardizeVecF32(Array2d& vec) { - auto mean = Preprocess::GetMean(vec); - auto stddev = Preprocess::GetStdDev(vec, mean); + auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); debug("Mean: %f, Stddev: %f\n", mean, stddev); if (stddev == 0) { @@ -204,14 +186,14 @@ namespace asr { } } - void Preprocess::Normalise() + void ASRPreProcess::Standarize() { - Preprocess::NormaliseVec(this->m_mfccBuf); - Preprocess::NormaliseVec(this->m_delta1Buf); - Preprocess::NormaliseVec(this->m_delta2Buf); + ASRPreProcess::StandardizeVecF32(this->m_mfccBuf); + ASRPreProcess::StandardizeVecF32(this->m_delta1Buf); + ASRPreProcess::StandardizeVecF32(this->m_delta2Buf); } - float Preprocess::GetQuantElem( + float ASRPreProcess::GetQuantElem( const float elem, const float quantScale, const int quantOffset, @@ -222,7 +204,5 @@ namespace asr { return std::min(std::max(val, minVal), maxVal); } -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp index abf20ab..ddf38c1 100644 --- a/source/use_case/kws/include/KwsProcessing.hpp +++ b/source/use_case/kws/include/KwsProcessing.hpp @@ -38,7 +38,7 @@ namespace app { public: /** * @brief Constructor - * @param[in] model Pointer to the the KWS Model object. + * @param[in] model Pointer to the KWS Model object. * @param[in] numFeatures How many MFCC features to use. * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when * sliding a window through the audio sample. @@ -107,24 +107,21 @@ namespace app { std::vector& m_results; public: - const float m_scoreThreshold; /** - * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the the Image classification Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in] results Vector of classification results to store decoded outputs. 
- * @param[in] scoreThreshold Predicted model score must be larger than this value to be accepted. + * @brief Constructor + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] model Pointer to the KWS Model object. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] results Vector of classification results to store decoded outputs. **/ KWSPostProcess(Classifier& classifier, Model* model, const std::vector& labels, - std::vector& results, - float scoreThreshold); + std::vector& results); /** - * @brief Should perform post-processing of the result of inference then populate - * populate KWS result data for any later use. - * @return true if successful, false otherwise. + * @brief Should perform post-processing of the result of inference then + * populate KWS result data for any later use. + * @return true if successful, false otherwise. **/ bool DoPostProcess() override; }; diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc index b6b230c..14f9fce 100644 --- a/source/use_case/kws/src/KwsProcessing.cc +++ b/source/use_case/kws/src/KwsProcessing.cc @@ -197,11 +197,10 @@ namespace app { KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model, const std::vector& labels, - std::vector& results, float scoreThreshold) + std::vector& results) :m_kwsClassifier{classifier}, m_labels{labels}, - m_results{results}, - m_scoreThreshold{scoreThreshold} + m_results{results} { if (!model->IsInited()) { printf_err("Model is not initialised!.\n"); diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 350d34b..e73a2c3 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -93,7 +93,7 @@ namespace app { std::vector singleInfResult; KWSPostProcess postprocess = KWSPostProcess(ctx.Get("classifier"), &model, ctx.Get&>("labels"), - singleInfResult, scoreThreshold); + singleInfResult); UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); @@ -146,7 +146,7 @@ namespace app { /* Add results from this window to our final results vector. 
*/ finalResults.emplace_back(kws::KwsResult(singleInfResult, audioDataSlider.Index() * secondsPerSample * preprocess.m_audioDataStride, - audioDataSlider.Index(), postprocess.m_scoreThreshold)); + audioDataSlider.Index(), scoreThreshold)); #if VERIFY_TEST_OUTPUT TfLiteTensor* outputTensor = model.GetOutputTensor(0); diff --git a/tests/common/PlatformMathTests.cpp b/tests/common/PlatformMathTests.cpp index ab1153f..c07cbf1 100644 --- a/tests/common/PlatformMathTests.cpp +++ b/tests/common/PlatformMathTests.cpp @@ -150,13 +150,28 @@ TEST_CASE("Test SqrtF32") TEST_CASE("Test MeanF32") { - /*Test Constants: */ + /* Test Constants: */ std::vector input {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 1.000}; /* Manually calculated mean of above vector */ float expectedResult = 0.100; CHECK (expectedResult == Approx(arm::app::math::MathUtils::MeanF32(input.data(), input.size()))); + + /* Mean of 0 */ + std::vector input2{1, 2, -1, -2}; + float expectedResult2 = 0.0f; + CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::MeanF32(input2.data(), input2.size()))); + + /* All 0s */ + std::vector input3 = std::vector(9, 0); + float expectedResult3 = 0.0f; + CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::MeanF32(input3.data(), input3.size()))); + + /* All 1s */ + std::vector input4 = std::vector(9, 1); + float expectedResult4 = 1.0f; + CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::MeanF32(input4.data(), input4.size()))); } TEST_CASE("Test StdDevF32") @@ -184,6 +199,22 @@ TEST_CASE("Test StdDevF32") float expectedResult = 0.969589282958136; CHECK (expectedResult == Approx(output)); + + /* All 0s should have 0 std dev. */ + std::vector input2 = std::vector(4, 0); + float expectedResult2 = 0.0f; + CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::StdDevF32(input2.data(), input2.size(), 0.0f))); + + /* All 1s should have 0 std dev. */ + std::vector input3 = std::vector(4, 1); + float expectedResult3 = 0.0f; + CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::StdDevF32(input3.data(), input3.size(), 1.0f))); + + /* Manually calclualted std value */ + std::vector input4 {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}; + float mean2 = (std::accumulate(input4.begin(), input4.end(), 0.0f))/float(input4.size()); + float expectedResult4 = 2.872281323; + CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::StdDevF32(input4.data(), input4.size(), mean2))); } TEST_CASE("Test FFT32") diff --git a/tests/use_case/asr/AsrFeaturesTests.cc b/tests/use_case/asr/AsrFeaturesTests.cc index 940c25f..6c23598 100644 --- a/tests/use_case/asr/AsrFeaturesTests.cc +++ b/tests/use_case/asr/AsrFeaturesTests.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,29 +23,19 @@ #include #include -class TestPreprocess : public arm::app::audio::asr::Preprocess { +class TestPreprocess : public arm::app::ASRPreProcess { public: static bool ComputeDeltas(arm::app::Array2d& mfcc, arm::app::Array2d& delta1, arm::app::Array2d& delta2) { - return Preprocess::ComputeDeltas(mfcc, delta1, delta2); - } - - static float GetMean(arm::app::Array2d& vec) - { - return Preprocess::GetMean(vec); - } - - static float GetStdDev(arm::app::Array2d& vec, const float mean) - { - return Preprocess::GetStdDev(vec, mean); + return ASRPreProcess::ComputeDeltas(mfcc, delta1, delta2); } static void NormaliseVec(arm::app::Array2d& vec) { - return Preprocess::NormaliseVec(vec); + return ASRPreProcess::StandardizeVecF32(vec); } }; @@ -126,40 +116,6 @@ TEST_CASE("Floating point asr features calculation", "[ASR]") } - SECTION("Mean") - { - std::vector> mean1vec{{1, 2}, - {-1, -2}}; - arm::app::Array2d mean1(2,2); /* {{1, 2},{-1, -2}} */ - populateArray2dWithVectorOfVector(mean1vec, mean1); - REQUIRE(0 == Approx(TestPreprocess::GetMean(mean1))); - - arm::app::Array2d mean2(2, 2); - std::fill(mean2.begin(), mean2.end(), 0.f); - REQUIRE(0 == Approx(TestPreprocess::GetMean(mean2))); - - arm::app::Array2d mean3(3,3); - std::fill(mean3.begin(), mean3.end(), 1.f); - REQUIRE(1 == Approx(TestPreprocess::GetMean(mean3))); - } - - SECTION("Std") - { - arm::app::Array2d std1(2, 2); - std::fill(std1.begin(), std1.end(), 0.f); /* {{0, 0}, {0, 0}} */ - REQUIRE(0 == Approx(TestPreprocess::GetStdDev(std1, 0))); - - std::vector> std2vec{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 0}}; - arm::app::Array2d std2(2,5); - populateArray2dWithVectorOfVector(std2vec, std2); - const float mean = TestPreprocess::GetMean(std2); - REQUIRE(2.872281323 == Approx(TestPreprocess::GetStdDev(std2, mean))); - - arm::app::Array2d std3(2,2); - std::fill(std3.begin(), std3.end(), 1.f); /* std3{{1, 1}, {1, 1}}; */ - REQUIRE(0 == Approx(TestPreprocess::GetStdDev(std3, 1))); - } - SECTION("Norm") { auto checker = [&](arm::app::Array2d& d, std::vector& g) { TestPreprocess::NormaliseVec(d); diff --git a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc index 9ed2e1b..d0b6505 100644 --- a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +16,7 @@ */ #include "Wav2LetterPostprocess.hpp" #include "Wav2LetterModel.hpp" +#include "ClassificationResult.hpp" #include #include @@ -47,85 +48,105 @@ TEST_CASE("Checking return value") { SECTION("Mismatched post processing parameters and tensor size") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; std::vector tensorShape = {1, 1, 1, 13}; std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); - REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + tensorShape, 100, tensorVec); + + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; + + REQUIRE(!post.DoPostProcess()); } SECTION("Post processing succeeds") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - - std::vector tensorShape = {1, 1, 13, 1}; - std::vector tensorVec; + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; + std::vector tensorShape = {1, 1, 13, 1}; + std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vector originalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); } } TEST_CASE("Postprocessing - erasing required elements") { - constexpr uint32_t ctxLen = 5; + constexpr uint32_t outputCtxLen = 5; constexpr uint32_t innerLen = 3; - constexpr uint32_t nRows = 2*ctxLen + innerLen; + constexpr uint32_t nRows = 2*outputCtxLen + innerLen; constexpr uint32_t nCols = 10; constexpr uint32_t blankTokenIdx = nCols - 1; - std::vector tensorShape = {1, 1, nRows, nCols}; + std::vector tensorShape = {1, 1, nRows, nCols}; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + std::vector dummyResult; SECTION("First and last iteration") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - - std::vector tensorVec; - TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + std::vector tensorVec; + TfLiteTensor tensor = GetTestTensor(tensorShape, 100, tensorVec); + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vectororiginalVec = tensorVec; /* This step should not erase anything. 
 
 TEST_CASE("Postprocessing - erasing required elements")
 {
-    constexpr uint32_t ctxLen = 5;
+    constexpr uint32_t outputCtxLen = 5;
     constexpr uint32_t innerLen = 3;
-    constexpr uint32_t nRows = 2*ctxLen + innerLen;
+    constexpr uint32_t nRows = 2*outputCtxLen + innerLen;
     constexpr uint32_t nCols = 10;
     constexpr uint32_t blankTokenIdx = nCols - 1;
-    std::vector<int> tensorShape = {1, 1, nRows, nCols};
+    std::vector<int> tensorShape = {1, 1, nRows, nCols};
+    arm::app::AsrClassifier classifier;
+    arm::app::Wav2LetterModel model;
+    model.Init();
+    std::vector<std::string> dummyLabels = {"a", "b", "$"};
+    std::vector<arm::app::ClassificationResult> dummyResult;
 
     SECTION("First and last iteration")
     {
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
-        std::vector<int8_t> tensorVec;
-        TfLiteTensor tensor = GetTestTensor<int8_t>(
-            tensorShape, 100, tensorVec);
+        std::vector<int8_t> tensorVec;
+        TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
+        arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+                blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
-        std::vector<int8_t> originalVec = tensorVec;
+        std::vector<int8_t> originalVec = tensorVec;
 
         /* This step should not erase anything. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+        post.m_lastIteration = true;
+        REQUIRE(post.DoPostProcess());
 
         REQUIRE(originalVec == tensorVec);
     }
 
     SECTION("Right context erase")
     {
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
         std::vector<int8_t> tensorVec;
         TfLiteTensor tensor = GetTestTensor<int8_t>(
-            tensorShape, 100, tensorVec);
+                tensorShape, 100, tensorVec);
+        arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+                blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
-        std::vector<int8_t> originalVec = tensorVec;
+        std::vector<int8_t> originalVec = tensorVec;
 
         /* This step should erase the right context only. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+        post.m_lastIteration = false;
+        REQUIRE(post.DoPostProcess());
         REQUIRE(originalVec != tensorVec);
 
         /* The last ctxLen * 10 elements should be gone. */
-        for (size_t i = 0; i < ctxLen; ++i) {
+        for (size_t i = 0; i < outputCtxLen; ++i) {
             for (size_t j = 0; j < nCols; ++j) {
-                /* Check right context elements are zeroed. */
+                /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */
                 if (j == blankTokenIdx) {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
                 } else {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
                 }
 
                 /* Check left context is preserved. */
@@ -134,46 +155,47 @@ TEST_CASE("Postprocessing - erasing required elements")
         }
 
         /* Check inner elements are preserved. */
-        for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+        for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
             CHECK(tensorVec[i] == originalVec[i]);
         }
     }
 
     SECTION("Left and right context erase")
     {
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
         std::vector<int8_t> tensorVec;
         TfLiteTensor tensor = GetTestTensor<int8_t>(
-            tensorShape, 100, tensorVec);
+                tensorShape, 100, tensorVec);
+        arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+                blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
         std::vector<int8_t> originalVec = tensorVec;
 
         /* This step should erase right context. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+        post.m_lastIteration = false;
+        REQUIRE(post.DoPostProcess());
 
         /* Calling it the second time should erase the left context. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+        REQUIRE(post.DoPostProcess());
 
         REQUIRE(originalVec != tensorVec);
 
         /* The first and last ctxLen * 10 elements should be gone. */
-        for (size_t i = 0; i < ctxLen; ++i) {
+        for (size_t i = 0; i < outputCtxLen; ++i) {
             for (size_t j = 0; j < nCols; ++j) {
                 /* Check left and right context elements are zeroed. */
                 if (j == blankTokenIdx) {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
                     CHECK(tensorVec[i*nCols + j] == 1);
                 } else {
-                    CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
+                    CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
                     CHECK(tensorVec[i*nCols + j] == 0);
                 }
             }
         }
 
         /* Check inner elements are preserved. */
-        for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+        for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
             CHECK(tensorVec[i] == originalVec[i]);
         }
     }
 
@@ -181,18 +203,20 @@ TEST_CASE("Postprocessing - erasing required elements")
 
     SECTION("Try left context erase")
     {
-        /* Should not be able to erase the left context if it is the first iteration. */
-        arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
         std::vector<int8_t> tensorVec;
         TfLiteTensor tensor = GetTestTensor<int8_t>(
-            tensorShape, 100, tensorVec);
+                tensorShape, 100, tensorVec);
+
+        /* Should not be able to erase the left context if it is the first iteration. */
+        arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+                blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
 
         /* Copy elements to compare later. */
         std::vector<int8_t> originalVec = tensorVec;
 
         /* First iteration: even with the last-iteration flag set, nothing should be erased. */
-        REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+        post.m_lastIteration = true;
+        REQUIRE(post.DoPostProcess());
 
         REQUIRE(originalVec == tensorVec);
     }
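Taken together, these sections pin down the sliding-window rule applied to overlapping inference windows: the right context is erased on every window except the last, the left context on every window except the first, and erased steps have their probability mass moved onto the blank token so that a CTC decoder emits nothing for them. A condensed sketch of that rule (illustrative only, not the ASRPostProcess implementation; a float tensor is assumed for brevity where the tests use int8):

    #include <cstdint>

    /* Zero out ctxLen rows starting at startRow, routing all probability
     * to the blank token. */
    static void EraseRows(float* logits, uint32_t startRow, uint32_t ctxLen,
                          uint32_t nCols, uint32_t blankIdx)
    {
        for (uint32_t i = startRow; i < startRow + ctxLen; ++i) {
            for (uint32_t j = 0; j < nCols; ++j) {
                logits[i * nCols + j] = (j == blankIdx) ? 1.0f : 0.0f;
            }
        }
    }

    static void EraseContexts(float* logits, uint32_t nRows, uint32_t nCols,
                              uint32_t ctxLen, uint32_t blankIdx,
                              bool firstWindow, bool lastWindow)
    {
        if (!firstWindow) {
            EraseRows(logits, 0, ctxLen, nCols, blankIdx);              /* Left context. */
        }
        if (!lastWindow) {
            EraseRows(logits, nRows - ctxLen, ctxLen, nCols, blankIdx); /* Right context. */
        }
    }

Under this rule the "Try left context erase" outcome follows directly: a window that is both first and last leaves the tensor untouched.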
diff --git a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
index 457257f..0280af6 100644
--- a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
+++ b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,55 +24,46 @@ constexpr uint32_t numMfccVectors = 10;
 
 /* Test vector output: generated using test-asr-preprocessing.py. */
 int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = {
-    /* Feature vec 0. */
-    -32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,    /* MFCCs. */
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,  /* Delta 1. */
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,  /* Delta 2. */
-
-    /* Feature vec 1. */
-    -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 2. */
-    -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 3. */
-    -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 4 : this should have valid delta 1 and delta 2. */
-    -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
-    -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5,
-    -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15,
-
-    /* Feature vec 5 : this should have valid delta 1 and delta 2. */
-    -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
-    -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13,
-    -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6,
-
-    /* Feature vec 6. */
-    -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 7. */
-    -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 8. */
-    -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
-    /* Feature vec 9. */
-    -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
-    -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
-    -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10
+    /* Feature vec 0. */
+    {-32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,     /* MFCCs. */
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, /* Delta 2. */
+    /* Feature vec 1. */
+    {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+    /* Feature vec 2. */
+    {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+    /* Feature vec 3. */
+    {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+    /* Feature vec 4 : this should have valid delta 1 and delta 2. */
+    {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+     -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5,
+     -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15},
+    /* Feature vec 5 : this should have valid delta 1 and delta 2. */
+    {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+     -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13,
+     -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6},
+    /* Feature vec 6. */
+    {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+    /* Feature vec 7. */
+    {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+    /* Feature vec 8. */
+    {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+    /* Feature vec 9. */
+    {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+     -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+     -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}
 };
 
 void PopulateTestWavVector(std::vector<int16_t>& vec)
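Each row of expectedResult above is one 39-element feature vector: 13 MFCCs followed by their 13 first-order and 13 second-order deltas, affine-quantised to int8 with the quantScale and quantOffset constants used below. A sketch of that quantisation step (the round-to-nearest and clamping behaviour is an assumption, not taken from the project sources):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    static int8_t QuantiseToInt8(float value, float scale, int32_t offset)
    {
        const int32_t q =
            static_cast<int32_t>(std::round(value / scale)) + offset;
        return static_cast<int8_t>(std::clamp<int32_t>(q, -128, 127));
    }

For example, a raw feature value of about -2.96 with scale 0.141022 and offset -11 lands at round(-2.96 / 0.141022) - 11 = -32, the first entry of feature vector 0.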
@@ -97,15 +88,16 @@
 TEST_CASE("Preprocessing calculation INT8")
 {
     /* Constants. */
-    const uint32_t windowLen = 512;
-    const uint32_t windowStride = 160;
-    int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors};
-    const float quantScale = 0.1410219967365265;
-    const int quantOffset = -11;
+    const uint32_t mfccWindowLen = 512;
+    const uint32_t mfccWindowStride = 160;
+    int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors};
+    const float quantScale = 0.1410219967365265;
+    const int quantOffset = -11;
 
     /* Test wav memory. */
-    std::vector<int16_t> testWav((windowStride * numMfccVectors) +
-                                 (windowLen - windowStride));
+    std::vector<int16_t> testWav((mfccWindowStride * numMfccVectors) +
+                                 (mfccWindowLen - mfccWindowStride)
+                                 );
 
     /* Populate with dummy input. */
     PopulateTestWavVector(testWav);
@@ -115,20 +107,20 @@ TEST_CASE("Preprocessing calculation INT8")
 
     /* Initialise dimensions and the test tensor. */
     TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray);
-    TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor(
-        tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
+    TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor(
+        tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
 
     /* Initialise pre-processing module. */
-    arm::app::audio::asr::Preprocess prep{
-        numMfccFeatures, windowLen, windowStride, numMfccVectors};
+    arm::app::ASRPreProcess prep{&inputTensor,
+        numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride};
 
     /* Invoke pre-processing. */
-    REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor));
+    REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size()));
 
     /* Wrap the tensor with a std::vector for ease. */
-    auto* tensorData = tflite::GetTensorData<int8_t>(&tensor);
+    auto* tensorData = tflite::GetTensorData<int8_t>(&inputTensor);
     std::vector<int8_t> vecResults =
-        std::vector<int8_t>(tensorData, tensorData + tensor.bytes);
+        std::vector<int8_t>(tensorData, tensorData + inputTensor.bytes);
 
     /* Check sizes. */
     REQUIRE(vecResults.size() == sizeof(expectedResult));
--
cgit v1.2.1