authorRichard Burton <richard.burton@arm.com>2022-04-22 09:08:21 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-22 09:08:21 +0100
commitc291144b7f08c21d08cdaf79cc64dc420ca70070 (patch)
tree1b91c38f7dd479a0c13772a1e1da52079d06237c /source
parentb1904b11d15da48c7ead4e6bb85c3e671956ab03 (diff)
downloadml-embedded-evaluation-kit-c291144b7f08c21d08cdaf79cc64dc420ca70070.tar.gz
MLECO-3077: Add ASR use case API
* Minor adjustments to doc strings in KWS
* Remove unused score threshold in KWS

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: Ie1c5bf6f7bdbebb853b6a10cb7ba1c4a1d9a76c9
Diffstat (limited to 'source')
-rw-r--r--  source/use_case/asr/include/AsrResult.hpp                2
-rw-r--r--  source/use_case/asr/include/Wav2LetterModel.hpp          3
-rw-r--r--  source/use_case/asr/include/Wav2LetterPostprocess.hpp  115
-rw-r--r--  source/use_case/asr/include/Wav2LetterPreprocess.hpp    96
-rw-r--r--  source/use_case/asr/src/MainLoop.cc                     85
-rw-r--r--  source/use_case/asr/src/UseCaseHandler.cc              166
-rw-r--r--  source/use_case/asr/src/Wav2LetterMfcc.cc                1
-rw-r--r--  source/use_case/asr/src/Wav2LetterModel.cc               1
-rw-r--r--  source/use_case/asr/src/Wav2LetterPostprocess.cc       153
-rw-r--r--  source/use_case/asr/src/Wav2LetterPreprocess.cc        106
-rw-r--r--  source/use_case/kws/include/KwsProcessing.hpp           23
-rw-r--r--  source/use_case/kws/src/KwsProcessing.cc                 5
-rw-r--r--  source/use_case/kws/src/UseCaseHandler.cc                4
13 files changed, 329 insertions, 431 deletions
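Taken together, the hunks below replace the old free-standing Preprocess/Postprocess objects (driven by explicit Invoke() calls) with BasePreProcess/BasePostProcess implementations that plug into a UseCaseRunner. A condensed sketch of the resulting flow, assembled from the UseCaseHandler.cc changes in this patch (setup is elided; names such as slider, windowLen, classifier and labels stand in for the handler's locals and application-context lookups):

    ASRPreProcess preProcess(model.GetInputTensor(0),
                             Wav2LetterModel::ms_numMfccFeatures,
                             inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
                             mfccFrameLen, mfccFrameStride);

    std::vector<ClassificationResult> singleInfResult;
    const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
    ASRPostProcess postProcess(classifier, model.GetOutputTensor(0), labels,
                               singleInfResult, outputCtxLen,
                               Wav2LetterModel::ms_blankTokenIdx,
                               Wav2LetterModel::ms_outputRowsIdx);

    UseCaseRunner runner(&preProcess, &postProcess, &model);

    while (slider.HasNext()) {                      /* One inference per audio window. */
        runner.PreProcess(slider.Next(), windowLen);
        if (!runner.RunInference()) { return false; }
        postProcess.m_lastIteration = !slider.HasNext();
        if (!runner.PostProcess()) { return false; }
    }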
diff --git a/source/use_case/asr/include/AsrResult.hpp b/source/use_case/asr/include/AsrResult.hpp
index b12ed7d..ed826d0 100644
--- a/source/use_case/asr/include/AsrResult.hpp
+++ b/source/use_case/asr/include/AsrResult.hpp
@@ -25,7 +25,7 @@ namespace arm {
namespace app {
namespace asr {
- using ResultVec = std::vector < arm::app::ClassificationResult >;
+ using ResultVec = std::vector<arm::app::ClassificationResult>;
/* Structure for holding ASR result. */
class AsrResult {
diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp
index 55395b9..895df2b 100644
--- a/source/use_case/asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/asr/include/Wav2LetterModel.hpp
@@ -36,6 +36,9 @@ namespace app {
static constexpr uint32_t ms_outputRowsIdx = 2;
static constexpr uint32_t ms_outputColsIdx = 3;
+ static constexpr uint32_t ms_blankTokenIdx = 28;
+ static constexpr uint32_t ms_numMfccFeatures = 13;
+
protected:
/** @brief Gets the reference to op resolver interface class. */
const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
index 29eb548..45defa5 100644
--- a/source/use_case/asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,93 +17,90 @@
#ifndef ASR_WAV2LETTER_POSTPROCESS_HPP
#define ASR_WAV2LETTER_POSTPROCESS_HPP
-#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */
+#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */
+#include "BaseProcessing.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/**
* @brief Helper class to manage tensor post-processing for "wav2letter"
* output.
*/
- class Postprocess {
+ class ASRPostProcess : public BasePostProcess {
public:
+ bool m_lastIteration = false; /* Set to true when processing the final set of data for a clip. */
+
/**
- * @brief Constructor
- * @param[in] contextLen Left and right context length for
- * output tensor.
- * @param[in] innerLen This is the length of the section
- * between left and right context.
- * @param[in] blankTokenIdx Blank token index.
+ * @brief Constructor
+ * @param[in] outputTensor Pointer to the output Tensor.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in/out] result Vector of classification results to store decoded outputs.
+ * @param[in] outputContextLen Left/right context length for output tensor.
+ * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes.
+ * @param[in] reductionAxis The axis that the logits of each time step are on.
**/
- Postprocess(uint32_t contextLen,
- uint32_t innerLen,
- uint32_t blankTokenIdx);
-
- Postprocess() = delete;
- ~Postprocess() = default;
+ ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+ const std::vector<std::string>& labels, asr::ResultVec& result,
+ uint32_t outputContextLen,
+ uint32_t blankTokenIdx, uint32_t reductionAxis);
/**
- * @brief Erases the required part of the tensor based
- * on context lengths set up during initialisation.
- * @param[in] tensor Pointer to the tensor.
- * @param[in] axisIdx Index of the axis on which erase is
- * performed.
- * @param[in] lastIteration Flag to signal this is the
- * last iteration in which case
- * the right context is preserved.
- * @return true if successful, false otherwise.
- */
- bool Invoke(TfLiteTensor* tensor,
- uint32_t axisIdx,
- bool lastIteration = false);
+ * @brief Should perform post-processing of the result of inference then
+ * populate ASR result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+
+ /** @brief Gets the output inner length for post-processing. */
+ static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen);
+
+ /** @brief Gets the output context length (left/right) for post-processing. */
+ static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen);
+
+ /** @brief Gets the number of feature vectors to be computed. */
+ static uint32_t GetNumFeatureVectors(const Model& model);
private:
- uint32_t m_contextLen; /* lengths of left and right contexts. */
- uint32_t m_innerLen; /* Length of inner context. */
- uint32_t m_totalLen; /* Total length of the required axis. */
- uint32_t m_countIterations; /* Current number of iterations. */
- uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ AsrClassifier& m_classifier; /* ASR Classifier object. */
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ const std::vector<std::string>& m_labels; /* ASR Labels. */
+ asr::ResultVec& m_results; /* Results vector for a single inference. */
+ uint32_t m_outputContextLen; /* Lengths of left/right contexts for output. */
+ uint32_t m_outputInnerLen; /* Length of output inner context. */
+ uint32_t m_totalLen; /* Total length of the required axis. */
+ uint32_t m_countIterations; /* Current number of iterations. */
+ uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */
+
/**
- * @brief Checks if the tensor and axis index are valid
- * inputs to the object - based on how it has been
- * initialised.
- * @return true if valid, false otherwise.
+ * @brief Checks if the tensor and axis index are valid
+ * inputs to the object - based on how it has been initialised.
+ * @return true if valid, false otherwise.
*/
bool IsInputValid(TfLiteTensor* tensor,
- const uint32_t axisIdx) const;
+ uint32_t axisIdx) const;
/**
- * @brief Gets the tensor data element size in bytes based
- * on the tensor type.
- * @return Size in bytes, 0 if not supported.
+ * @brief Gets the tensor data element size in bytes based
+ * on the tensor type.
+ * @return Size in bytes, 0 if not supported.
*/
static uint32_t GetTensorElementSize(TfLiteTensor* tensor);
/**
- * @brief Erases sections from the data assuming row-wise
- * arrangement along the context axis.
- * @return true if successful, false otherwise.
+ * @brief Erases sections from the data assuming row-wise
+ * arrangement along the context axis.
+ * @return true if successful, false otherwise.
*/
bool EraseSectionsRowWise(uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration);
-
- /**
- * @brief Erases sections from the data assuming col-wise
- * arrangement along the context axis.
- * @return true if successful, false otherwise.
- */
- static bool EraseSectionsColWise(const uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration);
+ uint32_t strideSzBytes,
+ bool lastIteration);
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
index 13d1589..8c12b3d 100644
--- a/source/use_case/asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,49 +21,44 @@
#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "DataStructures.hpp"
+#include "BaseProcessing.hpp"
#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/* Class to facilitate pre-processing calculation for Wav2Letter model
* for ASR. */
- using AudioWindow = SlidingWindow <const int16_t>;
+ using AudioWindow = audio::SlidingWindow<const int16_t>;
- class Preprocess {
+ class ASRPreProcess : public BasePreProcess {
public:
/**
* @brief Constructor.
- * @param[in] numMfccFeatures Number of MFCC features per window.
- * @param[in] windowLen Number of elements in a window.
- * @param[in] windowStride Stride (in number of elements) for
- * moving the window.
- * @param[in] numMfccVectors Number of MFCC vectors per window.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] numMfccFeatures Number of MFCC features per window.
+ * @param[in] audioWindowLen Number of MFCC vectors that need to be calculated
+ * for an inference.
+ * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window.
+ * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window.
*/
- Preprocess(
- uint32_t numMfccFeatures,
- uint32_t windowLen,
- uint32_t windowStride,
- uint32_t numMfccVectors);
- Preprocess() = delete;
- ~Preprocess() = default;
+ ASRPreProcess(TfLiteTensor* inputTensor,
+ uint32_t numMfccFeatures,
+ uint32_t audioWindowLen,
+ uint32_t mfccWindowLen,
+ uint32_t mfccWindowStride);
/**
* @brief Calculates the features required from audio data. This
* includes MFCC, first and second order deltas,
* normalisation and finally, quantisation. The tensor is
- * populated with feature from a given window placed along
+ * populated with features from a given window placed along
* in a single row.
* @param[in] audioData Pointer to the first element of audio data.
* @param[in] audioDataLen Number of elements in the audio data.
- * @param[in] tensor Tensor to be populated.
* @return true if successful, false in case of error.
*/
- bool Invoke(const int16_t * audioData,
- uint32_t audioDataLen,
- TfLiteTensor * tensor);
+ bool DoPreProcess(const void* audioData, size_t audioDataLen) override;
protected:
/**
@@ -80,32 +75,16 @@ namespace asr {
Array2d<float>& delta2);
/**
- * @brief Given a 2D vector of floats, computes the mean.
- * @param[in] vec Vctor of vector of floats.
- * @return Mean value.
- */
- static float GetMean(Array2d<float>& vec);
-
- /**
- * @brief Given a 2D vector of floats, computes the stddev.
- * @param[in] vec Vector of vector of floats.
- * @param[in] mean Mean value of the vector passed in.
- * @return stddev value.
- */
- static float GetStdDev(Array2d<float>& vec,
- const float mean);
-
- /**
- * @brief Given a 2D vector of floats, normalises it using
- * the mean and the stddev.
+ * @brief Given a 2D vector of floats, rescale it to have mean of 0 and
+ * standard deviation of 1.
* @param[in,out] vec Vector of vector of floats.
*/
- static void NormaliseVec(Array2d<float>& vec);
+ static void StandardizeVecF32(Array2d<float>& vec);
/**
- * @brief Normalises the MFCC and delta buffers.
+ * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1.
*/
- void Normalise();
+ void Standardize();
/**
* @brief Given the quantisation and data type limits, computes
@@ -139,7 +118,7 @@ namespace asr {
*/
template <typename T>
bool Quantise(
- T * outputBuf,
+ T* outputBuf,
const uint32_t outputBufSz,
const float quantScale,
const int quantOffset)
@@ -160,15 +139,15 @@ namespace asr {
const float maxVal = std::numeric_limits<T>::max();
/* Need to transpose while copying and concatenating the tensor. */
- for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+ for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
- *outputBufMfcc++ = static_cast<T>(Preprocess::GetQuantElem(
+ *outputBufMfcc++ = static_cast<T>(ASRPreProcess::GetQuantElem(
this->m_mfccBuf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD1++ = static_cast<T>(Preprocess::GetQuantElem(
+ *outputBufD1++ = static_cast<T>(ASRPreProcess::GetQuantElem(
this->m_delta1Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD2++ = static_cast<T>(Preprocess::GetQuantElem(
+ *outputBufD2++ = static_cast<T>(ASRPreProcess::GetQuantElem(
this->m_delta2Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
}
@@ -181,23 +160,22 @@ namespace asr {
}
private:
- Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
/* Actual buffers to be populated. */
- Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
- Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
- Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
+ Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
+ Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
+ Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
- uint32_t m_windowLen; /* Window length for MFCC. */
- uint32_t m_windowStride; /* Window stride len for MFCC. */
- uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
- uint32_t m_numFeatVectors; /* Number of m_numMfccFeats. */
- AudioWindow m_window; /* Sliding window. */
+ uint32_t m_mfccWindowLen; /* Window length for MFCC. */
+ uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */
+ uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
+ uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */
+ AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
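A note on the Quantise() template above: the three feature buffers are written out transposed and concatenated per time step, so each row of the input tensor carries the MFCCs followed by their first and second order deltas. Illustratively, with the 13 features fixed by Wav2LetterModel::ms_numMfccFeatures (the exact pointer arithmetic is elided by the diff, but the layout is implied by the multiple-of-3 check on input columns in MainLoop.cc below):

    /* One input-tensor row (one feature frame):
     *
     *   [ mfcc[0..12] | delta1[0..12] | delta2[0..12] ]  => 3 * 13 = 39 columns
     */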
diff --git a/source/use_case/asr/src/MainLoop.cc b/source/use_case/asr/src/MainLoop.cc
index 51b0b18..a1a9540 100644
--- a/source/use_case/asr/src/MainLoop.cc
+++ b/source/use_case/asr/src/MainLoop.cc
@@ -14,15 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "hal.h" /* Brings in platform definitions. */
#include "Labels.hpp" /* For label strings. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
#include "Wav2LetterModel.hpp" /* Model class for running inference. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
#include "AsrClassifier.hpp" /* Classifier. */
#include "InputFiles.hpp" /* Generated audio clip header. */
-#include "Wav2LetterPreprocess.hpp" /* Pre-processing class. */
-#include "Wav2LetterPostprocess.hpp" /* Post-processing class. */
#include "log_macros.h"
enum opcodes
@@ -48,23 +45,9 @@ static void DisplayMenu()
fflush(stdout);
}
-/** @brief Verify input and output tensor are of certain min dimensions. */
+/** @brief Verify input and output tensors are of certain minimum dimensions. */
static bool VerifyTensorDimensions(const arm::app::Model& model);
-/** @brief Gets the number of MFCC features for a single window. */
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model);
-
-/** @brief Gets the number of MFCC feature vectors to be computed. */
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model);
-
-/** @brief Gets the output context length (left and right) for post-processing. */
-static uint32_t GetOutputContextLen(const arm::app::Model& model,
- uint32_t inputCtxLen);
-
-/** @brief Gets the output inner length for post-processing. */
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- uint32_t outputCtxLen);
-
void main_loop()
{
arm::app::Wav2LetterModel model; /* Model wrapper object. */
@@ -78,21 +61,6 @@ void main_loop()
return;
}
- /* Initialise pre-processing. */
- arm::app::audio::asr::Preprocess prep(
- GetNumMfccFeatures(model),
- g_FrameLength,
- g_FrameStride,
- GetNumMfccFeatureVectors(model));
-
- /* Initialise post-processing. */
- const uint32_t outputCtxLen = GetOutputContextLen(model, g_ctxLen);
- const uint32_t blankTokenIdx = 28;
- arm::app::audio::asr::Postprocess postp(
- outputCtxLen,
- GetOutputInnerLen(model, outputCtxLen),
- blankTokenIdx);
-
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
std::vector <std::string> labels;
@@ -109,8 +77,6 @@ void main_loop()
caseContext.Set<uint32_t>("ctxLen", g_ctxLen); /* Left and right context length (MFCC feat vectors). */
caseContext.Set<const std::vector <std::string>&>("labels", labels);
caseContext.Set<arm::app::AsrClassifier&>("classifier", classifier);
- caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep);
- caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp);
bool executionSuccessful = true;
constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false;
@@ -184,52 +150,3 @@ static bool VerifyTensorDimensions(const arm::app::Model& model)
return true;
}
-
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model)
-{
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
- if (0 != inputCols % 3) {
- printf_err("Number of input columns is not a multiple of 3\n");
- }
- return std::max(inputCols/3, 0);
-}
-
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model)
-{
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
- return std::max(inputRows, 0);
-}
-
-static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen)
-{
- const uint32_t inputRows = GetNumMfccFeatureVectors(model);
- const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-
- /* Check to make sure that the input tensor supports the above
- * context and inner lengths. */
- if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
- printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
- inputCtxLen);
- return 0;
- }
-
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
-
- const float tensorColRatio = static_cast<float>(inputRows)/
- static_cast<float>(outputRows);
-
- return std::round(static_cast<float>(inputCtxLen)/tensorColRatio);
-}
-
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- const uint32_t outputCtxLen)
-{
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
- return (outputRows - (2 * outputCtxLen));
-}
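For callers that still need these derived lengths, the equivalents now live as statics on ASRPostProcess (defined in Wav2LetterPostprocess.cc below); for example, the context length previously computed here becomes:

    const uint32_t outputCtxLen =
        arm::app::ASRPostProcess::GetOutputContextLen(model, g_ctxLen);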
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 420f725..7fe959b 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -20,7 +20,6 @@
#include "AsrClassifier.hpp"
#include "Wav2LetterModel.hpp"
#include "hal.h"
-#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
@@ -34,68 +33,63 @@ namespace arm {
namespace app {
/**
- * @brief Presents inference results using the data presentation
- * object.
- * @param[in] results Vector of classification results to be displayed.
+ * @brief Presents ASR inference results.
+ * @param[in] results Vector of ASR classification results to be displayed.
* @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results);
+ static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
- /* Audio inference classification handler. */
+ /* ASR inference handler. */
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
- constexpr uint32_t dataPsnTxtInfStartX = 20;
- constexpr uint32_t dataPsnTxtInfStartY = 40;
-
- hal_lcd_clear(COLOR_BLACK);
-
+ auto& model = ctx.Get<Model&>("model");
auto& profiler = ctx.Get<Profiler&>("profiler");
-
+ auto mfccFrameLen = ctx.Get<uint32_t>("frameLength");
+ auto mfccFrameStride = ctx.Get<uint32_t>("frameStride");
+ auto scoreThreshold = ctx.Get<float>("scoreThreshold");
+ auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
/* If the request has a valid size, set the audio index. */
if (clipIndex < NUMBER_OF_FILES) {
if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
return false;
}
}
+ auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
+ constexpr uint32_t dataPsnTxtInfStartX = 20;
+ constexpr uint32_t dataPsnTxtInfStartY = 40;
- /* Get model reference. */
- auto& model = ctx.Get<Model&>("model");
if (!model.IsInited()) {
printf_err("Model is not initialised! Terminating processing.\n");
return false;
}
- /* Get score threshold to be applied for the classifier (post-inference). */
- auto scoreThreshold = ctx.Get<float>("scoreThreshold");
-
- /* Get tensors. Dimensions of the tensor should have been verified by
+ /* Get input shape. Dimensions of the tensor should have been verified by
* the callee. */
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
+ TfLiteIntArray* inputShape = model.GetInputShape(0);
- /* Populate MFCC related parameters. */
- auto mfccParamsWinLen = ctx.Get<uint32_t>("frameLength");
- auto mfccParamsWinStride = ctx.Get<uint32_t>("frameStride");
-
- /* Populate ASR inference context and inner lengths for input. */
- auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
- const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+ const uint32_t inputRowsSize = inputShape->data[Wav2LetterModel::ms_inputRowsIdx];
+ const uint32_t inputInnerLen = inputRowsSize - (2 * inputCtxLen);
/* Audio data stride corresponds to inputInnerLen feature vectors. */
- const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen);
- const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride;
- const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
+ const uint32_t audioDataWindowLen = (inputRowsSize - 1) * mfccFrameStride + (mfccFrameLen);
+ const uint32_t audioDataWindowStride = inputInnerLen * mfccFrameStride;
+
+ /* NOTE: This is only used for time stamp calculation. */
+ const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
- /* Get pre/post-processing objects. */
- auto& prep = ctx.Get<audio::asr::Preprocess&>("preprocess");
- auto& postp = ctx.Get<audio::asr::Postprocess&>("postprocess");
+ /* Set up pre and post-processing objects. */
+ ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures,
+ inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride);
- /* Set default reduction axis for post-processing. */
- const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+ std::vector<ClassificationResult> singleInfResult;
+ const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
+ ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"),
+ model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"),
+ singleInfResult, outputCtxLen,
+ Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+ );
- /* Audio clip start index. */
- auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
+ UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model);
/* Loop to process audio clips. */
do {
@@ -109,44 +103,41 @@ namespace app {
const uint32_t audioArrSize = get_audio_array_size(currentIndex);
if (!audioArr) {
- printf_err("Invalid audio array pointer\n");
+ printf_err("Invalid audio array pointer.\n");
return false;
}
- /* Audio clip must have enough samples to produce 1 MFCC feature. */
- if (audioArrSize < mfccParamsWinLen) {
+ /* Audio clip needs enough samples to produce at least 1 MFCC feature. */
+ if (audioArrSize < mfccFrameLen) {
printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
- mfccParamsWinLen);
+ mfccFrameLen);
return false;
}
- /* Initialise an audio slider. */
+ /* Creating a sliding window through the whole audio clip. */
auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
- audioArr,
- audioArrSize,
- audioParamsWinLen,
- audioParamsWinStride);
+ audioArr, audioArrSize,
+ audioDataWindowLen, audioDataWindowStride);
- /* Declare a container for results. */
- std::vector<arm::app::asr::AsrResult> results;
+ /* Declare a container for final results. */
+ std::vector<asr::AsrResult> finalResults;
/* Display message on the LCD - inference running. */
std::string str_inf{"Running inference... "};
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
get_filename(currentIndex));
- size_t inferenceWindowLen = audioParamsWinLen;
+ size_t inferenceWindowLen = audioDataWindowLen;
/* Start sliding through audio clip. */
while (audioDataSlider.HasNext()) {
- /* If not enough audio see how much can be sent for processing. */
+ /* If not enough audio, see how much can be sent for processing. */
size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
- if (nextStartIndex + audioParamsWinLen > audioArrSize) {
+ if (nextStartIndex + audioDataWindowLen > audioArrSize) {
inferenceWindowLen = audioArrSize - nextStartIndex;
}
@@ -155,46 +146,40 @@ namespace app {
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
- /* Calculate MFCCs, deltas and populate the input tensor. */
- prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor);
+ /* Run the pre-processing, inference and post-processing. */
+ runner.PreProcess(inferenceWindow, inferenceWindowLen);
- /* Run inference over this audio clip sliding window. */
- if (!RunInference(model, profiler)) {
+ profiler.StartProfiling("Inference");
+ if (!runner.RunInference()) {
return false;
}
+ profiler.StopProfiling();
- /* Post-process. */
- postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext());
-
- /* Get results. */
- std::vector<ClassificationResult> classificationResult;
- auto& classifier = ctx.Get<AsrClassifier&>("classifier");
- classifier.GetClassificationResults(
- outputTensor, classificationResult,
- ctx.Get<std::vector<std::string>&>("labels"), 1);
+ postProcess.m_lastIteration = !audioDataSlider.HasNext();
+ if (!runner.PostProcess()) {
+ return false;
+ }
- results.emplace_back(asr::AsrResult(classificationResult,
- (audioDataSlider.Index() *
- audioParamsSecondsPerSample *
- audioParamsWinStride),
- audioDataSlider.Index(), scoreThreshold));
+ /* Add results from this window to our final results vector. */
+ finalResults.emplace_back(asr::AsrResult(singleInfResult,
+ (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride),
+ audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(outputTensor,
- outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+ armDumpTensor(outputTensor,
+ outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */
-
- }
+ } /* while (audioDataSlider.HasNext()) */
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
- ctx.Set<std::vector<arm::app::asr::AsrResult>>("results", results);
+ ctx.Set<std::vector<asr::AsrResult>>("results", finalResults);
- if (!PresentInferenceResult(results)) {
+ if (!PresentInferenceResult(finalResults)) {
return false;
}
@@ -202,13 +187,13 @@ namespace app {
IncrementAppCtxIfmIdx(ctx,"clipIndex");
- } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
+ } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx);
return true;
}
- static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results)
+ static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results)
{
constexpr uint32_t dataPsnTxtStartX1 = 20;
constexpr uint32_t dataPsnTxtStartY1 = 60;
@@ -219,15 +204,15 @@ namespace app {
info("Final results:\n");
info("Total number of inferences: %zu\n", results.size());
/* Results from multiple inferences should be combined before processing. */
- std::vector<arm::app::ClassificationResult> combinedResults;
- for (auto& result : results) {
+ std::vector<ClassificationResult> combinedResults;
+ for (const auto& result : results) {
combinedResults.insert(combinedResults.end(),
result.m_resultVec.begin(),
result.m_resultVec.end());
}
/* Get each inference result string using the decoder. */
- for (const auto & result : results) {
+ for (const auto& result : results) {
std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n",
@@ -238,10 +223,9 @@ namespace app {
/* Get the decoded result for the combined result. */
std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
- hal_lcd_display_text(
- finalResultStr.c_str(), finalResultStr.size(),
- dataPsnTxtStartX1, dataPsnTxtStartY1,
- allow_multiple_lines);
+ hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+ dataPsnTxtStartX1, dataPsnTxtStartY1,
+ allow_multiple_lines);
info("Complete recognition: %s\n", finalResultStr.c_str());
return true;
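For reference, the timestamp handed to each AsrResult above is plain sample arithmetic. A worked example, assuming the default 16 kHz sampling rate behind Wav2LetterMFCC::ms_defaultSamplingFreq (the stride value is illustrative only):

    /* startTime = Index() * secondsPerSample * audioDataWindowStride
     *           = 2 * (1.0 / 16000) * 8000
     *           = 1.0 s for the third window (Index() == 2). */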
diff --git a/source/use_case/asr/src/Wav2LetterMfcc.cc b/source/use_case/asr/src/Wav2LetterMfcc.cc
index 1bcaa66..bb29b0f 100644
--- a/source/use_case/asr/src/Wav2LetterMfcc.cc
+++ b/source/use_case/asr/src/Wav2LetterMfcc.cc
@@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "Wav2LetterMfcc.hpp"
+
#include "PlatformMath.hpp"
#include "log_macros.h"
diff --git a/source/use_case/asr/src/Wav2LetterModel.cc b/source/use_case/asr/src/Wav2LetterModel.cc
index 766bce9..8b38f4f 100644
--- a/source/use_case/asr/src/Wav2LetterModel.cc
+++ b/source/use_case/asr/src/Wav2LetterModel.cc
@@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "Wav2LetterModel.hpp"
+
#include "log_macros.h"
diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc
index 0392061..e3e1999 100644
--- a/source/use_case/asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,67 +15,71 @@
* limitations under the License.
*/
#include "Wav2LetterPostprocess.hpp"
+
#include "Wav2LetterModel.hpp"
#include "log_macros.h"
+#include <cmath>
+
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
-
- Postprocess::Postprocess(const uint32_t contextLen,
- const uint32_t innerLen,
- const uint32_t blankTokenIdx)
- : m_contextLen(contextLen),
- m_innerLen(innerLen),
- m_totalLen(2 * this->m_contextLen + this->m_innerLen),
+
+ ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+ const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
+ const uint32_t outputContextLen,
+ const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
+ ):
+ m_classifier(classifier),
+ m_outputTensor(outputTensor),
+ m_labels{labels},
+ m_results(results),
+ m_outputContextLen(outputContextLen),
m_countIterations(0),
- m_blankTokenIdx(blankTokenIdx)
- {}
+ m_blankTokenIdx(blankTokenIdx),
+ m_reductionAxisIdx(reductionAxisIdx)
+ {
+ this->m_outputInnerLen = ASRPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
+ this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
+ }
- bool Postprocess::Invoke(TfLiteTensor* tensor,
- const uint32_t axisIdx,
- const bool lastIteration)
+ bool ASRPostProcess::DoPostProcess()
{
/* Basic checks. */
- if (!this->IsInputValid(tensor, axisIdx)) {
+ if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
return false;
}
/* Irrespective of tensor type, we use unsigned "byte" */
- uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
- const uint32_t elemSz = this->GetTensorElementSize(tensor);
+ auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
+ const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor);
/* Other sanity checks. */
if (0 == elemSz) {
printf_err("Tensor type not supported for post processing\n");
return false;
- } else if (elemSz * this->m_totalLen > tensor->bytes) {
+ } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) {
printf_err("Insufficient number of tensor bytes\n");
return false;
}
/* Which axis do we need to process? */
- switch (axisIdx) {
- case arm::app::Wav2LetterModel::ms_outputRowsIdx:
- return this->EraseSectionsRowWise(ptrData,
- elemSz *
- tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
- lastIteration);
- case arm::app::Wav2LetterModel::ms_outputColsIdx:
- return this->EraseSectionsColWise(ptrData,
- elemSz *
- tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx],
- lastIteration);
+ switch (this->m_reductionAxisIdx) {
+ case Wav2LetterModel::ms_outputRowsIdx:
+ this->EraseSectionsRowWise(
+ ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx],
+ this->m_lastIteration);
+ break;
default:
- printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
+ printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx);
+ return false;
}
+ this->m_classifier.GetClassificationResults(this->m_outputTensor,
+ this->m_results, this->m_labels, 1);
- return false;
+ return true;
}
- bool Postprocess::IsInputValid(TfLiteTensor* tensor,
- const uint32_t axisIdx) const
+ bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
{
if (nullptr == tensor) {
return false;
@@ -89,15 +93,15 @@ namespace asr {
if (static_cast<int>(this->m_totalLen) !=
tensor->dims->data[axisIdx]) {
- printf_err("Unexpected tensor dimension for axis %d, \n",
- tensor->dims->data[axisIdx]);
+ printf_err("Unexpected tensor dimension for axis %d, got %d, \n",
+ axisIdx, tensor->dims->data[axisIdx]);
return false;
}
return true;
}
- uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor)
+ uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
{
switch(tensor->type) {
case kTfLiteUInt8:
@@ -116,30 +120,30 @@ namespace asr {
return 0;
}
- bool Postprocess::EraseSectionsRowWise(
- uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration)
+ bool ASRPostProcess::EraseSectionsRowWise(
+ uint8_t* ptrData,
+ const uint32_t strideSzBytes,
+ const bool lastIteration)
{
/* In this case, the "zero-ing" is quite simple as the region
* to be zeroed sits in contiguous memory (row-major). */
- const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
+ const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen;
/* Erase left context? */
if (this->m_countIterations > 0) {
/* Set output of each classification window to the blank token. */
std::memset(ptrData, 0, eraseLen);
- for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+ for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
}
}
/* Erase right context? */
if (false == lastIteration) {
- uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
+ uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen));
/* Set output of each classification window to the blank token. */
std::memset(rightCtxPtr, 0, eraseLen);
- for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+ for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
}
}
@@ -153,19 +157,56 @@ namespace asr {
return true;
}
- bool Postprocess::EraseSectionsColWise(
- const uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration)
+ uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model)
+ {
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
+ if (inputRows == 0) {
+ printf_err("Error getting number of input rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_inputRowsIdx);
+ }
+ return inputRows;
+ }
+
+ uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
{
- /* Not implemented. */
- UNUSED(ptrData);
- UNUSED(strideSzBytes);
- UNUSED(lastIteration);
- return false;
+ const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
+ if (outputRows == 0) {
+ printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_outputRowsIdx);
+ }
+ const int innerLen = static_cast<int>(outputRows) - static_cast<int>(2 * outputCtxLen);
+
+ return std::max(innerLen, 0);
+ }
+
+ uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+ {
+ const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model);
+ const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+ constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
+
+ /* Check to make sure that the input tensor supports the above
+ * context and inner lengths. */
+ if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
+ printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
+ inputCtxLen);
+ return 0;
+ }
+
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+ const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
+ if (outputRows == 0) {
+ printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_outputRowsIdx);
+ return 0;
+ }
+
+ const float inOutRowRatio = static_cast<float>(inputRows) /
+ static_cast<float>(outputRows);
+
+ return std::round(static_cast<float>(inputCtxLen) / inOutRowRatio);
}
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
\ No newline at end of file
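The ratio arithmetic in GetOutputContextLen() is easier to follow with concrete numbers; the shapes below are purely illustrative, not taken from the deployed model:

    /* inputRows = 296, inputCtxLen = 98, outputRows = 148:
     *   inOutRowRatio  = 296 / 148       = 2.0
     *   outputCtxLen   = round(98 / 2.0) = 49
     *   outputInnerLen = 148 - (2 * 49)  = 50   (GetOutputInnerLen)
     *   m_totalLen     = (2 * 49) + 50   = 148  (matches outputRows)  */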
diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc
index e5ac3ca..590d08a 100644
--- a/source/use_case/asr/src/Wav2LetterPreprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,37 +24,31 @@
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
-
- Preprocess::Preprocess(
- const uint32_t numMfccFeatures,
- const uint32_t windowLen,
- const uint32_t windowStride,
- const uint32_t numMfccVectors):
- m_mfcc(numMfccFeatures, windowLen),
- m_mfccBuf(numMfccFeatures, numMfccVectors),
- m_delta1Buf(numMfccFeatures, numMfccVectors),
- m_delta2Buf(numMfccFeatures, numMfccVectors),
- m_windowLen(windowLen),
- m_windowStride(windowStride),
+
+ ASRPreProcess::ASRPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
+ const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
+ const uint32_t mfccWindowStride
+ ):
+ m_mfcc(numMfccFeatures, mfccWindowLen),
+ m_inputTensor(inputTensor),
+ m_mfccBuf(numMfccFeatures, numFeatureFrames),
+ m_delta1Buf(numMfccFeatures, numFeatureFrames),
+ m_delta2Buf(numMfccFeatures, numFeatureFrames),
+ m_mfccWindowLen(mfccWindowLen),
+ m_mfccWindowStride(mfccWindowStride),
m_numMfccFeats(numMfccFeatures),
- m_numFeatVectors(numMfccVectors),
- m_window()
+ m_numFeatureFrames(numFeatureFrames)
{
- if (numMfccFeatures > 0 && windowLen > 0) {
+ if (numMfccFeatures > 0 && mfccWindowLen > 0) {
this->m_mfcc.Init();
}
}
- bool Preprocess::Invoke(
- const int16_t* audioData,
- const uint32_t audioDataLen,
- TfLiteTensor* tensor)
+ bool ASRPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
{
- this->m_window = SlidingWindow<const int16_t>(
- audioData, audioDataLen,
- this->m_windowLen, this->m_windowStride);
+ this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(
+ static_cast<const int16_t*>(audioData), audioDataLen,
+ this->m_mfccWindowLen, this->m_mfccWindowStride);
uint32_t mfccBufIdx = 0;
@@ -62,12 +56,12 @@ namespace asr {
std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f);
std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f);
- /* While we can slide over the window. */
- while (this->m_window.HasNext()) {
- const int16_t* mfccWindow = this->m_window.Next();
+ /* While we can slide over the audio. */
+ while (this->m_mfccSlidingWindow.HasNext()) {
+ const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
auto mfccAudioData = std::vector<int16_t>(
mfccWindow,
- mfccWindow + this->m_windowLen);
+ mfccWindow + this->m_mfccWindowLen);
auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData);
for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) {
this->m_mfccBuf(i, mfccBufIdx) = mfcc[i];
@@ -76,11 +70,11 @@ namespace asr {
}
/* Pad MFCC if needed by adding MFCC for zeros. */
- if (mfccBufIdx != this->m_numFeatVectors) {
- std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0);
+ if (mfccBufIdx != this->m_numFeatureFrames) {
+ std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0);
std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow);
- while (mfccBufIdx != this->m_numFeatVectors) {
+ while (mfccBufIdx != this->m_numFeatureFrames) {
memcpy(&this->m_mfccBuf(0, mfccBufIdx),
mfccZeros.data(), sizeof(float) * m_numMfccFeats);
++mfccBufIdx;
@@ -88,39 +82,37 @@ namespace asr {
}
/* Compute first and second order deltas from MFCCs. */
- Preprocess::ComputeDeltas(this->m_mfccBuf,
- this->m_delta1Buf,
- this->m_delta2Buf);
+ ASRPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
- /* Normalise. */
- this->Normalise();
+ /* Standardize calculated features. */
+ this->Standardize();
/* Quantise. */
- QuantParams quantParams = GetTensorQuantParams(tensor);
+ QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor);
if (0 == quantParams.scale) {
printf_err("Quantisation scale can't be 0\n");
return false;
}
- switch(tensor->type) {
+ switch(this->m_inputTensor->type) {
case kTfLiteUInt8:
return this->Quantise<uint8_t>(
- tflite::GetTensorData<uint8_t>(tensor), tensor->bytes,
+ tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
quantParams.scale, quantParams.offset);
case kTfLiteInt8:
return this->Quantise<int8_t>(
- tflite::GetTensorData<int8_t>(tensor), tensor->bytes,
+ tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
quantParams.scale, quantParams.offset);
default:
printf_err("Unsupported tensor type %s\n",
- TfLiteTypeGetName(tensor->type));
+ TfLiteTypeGetName(this->m_inputTensor->type));
}
return false;
}
- bool Preprocess::ComputeDeltas(Array2d<float>& mfcc,
+ bool ASRPreProcess::ComputeDeltas(Array2d<float>& mfcc,
Array2d<float>& delta1,
Array2d<float>& delta2)
{
@@ -175,20 +167,10 @@ namespace asr {
return true;
}
- float Preprocess::GetMean(Array2d<float>& vec)
- {
- return math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
- }
-
- float Preprocess::GetStdDev(Array2d<float>& vec, const float mean)
- {
- return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
- }
-
- void Preprocess::NormaliseVec(Array2d<float>& vec)
+ void ASRPreProcess::StandardizeVecF32(Array2d<float>& vec)
{
- auto mean = Preprocess::GetMean(vec);
- auto stddev = Preprocess::GetStdDev(vec, mean);
+ auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
+ auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
debug("Mean: %f, Stddev: %f\n", mean, stddev);
if (stddev == 0) {
@@ -204,14 +186,14 @@ namespace asr {
}
}
- void Preprocess::Normalise()
+ void ASRPreProcess::Standardize()
{
- Preprocess::NormaliseVec(this->m_mfccBuf);
- Preprocess::NormaliseVec(this->m_delta1Buf);
- Preprocess::NormaliseVec(this->m_delta2Buf);
+ ASRPreProcess::StandardizeVecF32(this->m_mfccBuf);
+ ASRPreProcess::StandardizeVecF32(this->m_delta1Buf);
+ ASRPreProcess::StandardizeVecF32(this->m_delta2Buf);
}
- float Preprocess::GetQuantElem(
+ float ASRPreProcess::GetQuantElem(
const float elem,
const float quantScale,
const int quantOffset,
@@ -222,7 +204,5 @@ namespace asr {
return std::min<float>(std::max<float>(val, minVal), maxVal);
}
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
\ No newline at end of file
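As a self-contained sketch of what StandardizeVecF32() does to each buffer, here is the same rescaling over a plain std::vector instead of Array2d (illustrative only; it keeps the zero-stddev guard seen above):

    #include <cmath>
    #include <vector>

    static void StandardizeF32(std::vector<float>& vec)
    {
        if (vec.empty()) {
            return;
        }
        float mean = 0.f;
        for (float v : vec) { mean += v; }
        mean /= static_cast<float>(vec.size());

        float var = 0.f;
        for (float v : vec) { var += (v - mean) * (v - mean); }
        const float stddev = std::sqrt(var / static_cast<float>(vec.size()));

        /* Guard against division by zero, mirroring the check above. */
        if (stddev == 0.f) {
            return;
        }
        for (float& v : vec) {
            v = (v - mean) / stddev;   /* Rescale to mean 0, std. dev 1. */
        }
    }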
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp
index abf20ab..ddf38c1 100644
--- a/source/use_case/kws/include/KwsProcessing.hpp
+++ b/source/use_case/kws/include/KwsProcessing.hpp
@@ -38,7 +38,7 @@ namespace app {
public:
/**
* @brief Constructor
- * @param[in] model Pointer to the the KWS Model object.
+ * @param[in] model Pointer to the KWS Model object.
* @param[in] numFeatures How many MFCC features to use.
* @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
* sliding a window through the audio sample.
@@ -107,24 +107,21 @@ namespace app {
std::vector<ClassificationResult>& m_results;
public:
- const float m_scoreThreshold;
/**
- * @brief Constructor
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] model Pointer to the the Image classification Model object.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in] results Vector of classification results to store decoded outputs.
- * @param[in] scoreThreshold Predicted model score must be larger than this value to be accepted.
+ * @brief Constructor
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] model Pointer to the KWS Model object.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in/out] results Vector of classification results to store decoded outputs.
**/
KWSPostProcess(Classifier& classifier, Model* model,
const std::vector<std::string>& labels,
- std::vector<ClassificationResult>& results,
- float scoreThreshold);
+ std::vector<ClassificationResult>& results);
/**
- * @brief Should perform post-processing of the result of inference then populate
- * populate KWS result data for any later use.
- * @return true if successful, false otherwise.
+ * @brief Should perform post-processing of the result of inference then
+ * populate KWS result data for any later use.
+ * @return true if successful, false otherwise.
**/
bool DoPostProcess() override;
};
diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc
index b6b230c..14f9fce 100644
--- a/source/use_case/kws/src/KwsProcessing.cc
+++ b/source/use_case/kws/src/KwsProcessing.cc
@@ -197,11 +197,10 @@ namespace app {
KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model,
const std::vector<std::string>& labels,
- std::vector<ClassificationResult>& results, float scoreThreshold)
+ std::vector<ClassificationResult>& results)
:m_kwsClassifier{classifier},
m_labels{labels},
- m_results{results},
- m_scoreThreshold{scoreThreshold}
+ m_results{results}
{
if (!model->IsInited()) {
printf_err("Model is not initialised!.\n");
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index 350d34b..e73a2c3 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -93,7 +93,7 @@ namespace app {
std::vector<ClassificationResult> singleInfResult;
KWSPostProcess postprocess = KWSPostProcess(ctx.Get<KwsClassifier &>("classifier"), &model,
ctx.Get<std::vector<std::string>&>("labels"),
- singleInfResult, scoreThreshold);
+ singleInfResult);
UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
@@ -146,7 +146,7 @@ namespace app {
/* Add results from this window to our final results vector. */
finalResults.emplace_back(kws::KwsResult(singleInfResult,
audioDataSlider.Index() * secondsPerSample * preprocess.m_audioDataStride,
- audioDataSlider.Index(), postprocess.m_scoreThreshold));
+ audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
TfLiteTensor* outputTensor = model.GetOutputTensor(0);