author     Richard Burton <richard.burton@arm.com>    2022-04-22 09:08:21 +0100
committer  Richard Burton <richard.burton@arm.com>    2022-04-22 09:08:21 +0100
commit     c291144b7f08c21d08cdaf79cc64dc420ca70070 (patch)
tree       1b91c38f7dd479a0c13772a1e1da52079d06237c
parent     b1904b11d15da48c7ead4e6bb85c3e671956ab03 (diff)
download   ml-embedded-evaluation-kit-c291144b7f08c21d08cdaf79cc64dc420ca70070.tar.gz
MLECO-3077: Add ASR use case API
* Minor adjustments to doc strings in KWS
* Remove unused score threshold in KWS

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: Ie1c5bf6f7bdbebb853b6a10cb7ba1c4a1d9a76c9
-rw-r--r--  source/use_case/asr/include/AsrResult.hpp             |    2
-rw-r--r--  source/use_case/asr/include/Wav2LetterModel.hpp       |    3
-rw-r--r--  source/use_case/asr/include/Wav2LetterPostprocess.hpp |  115
-rw-r--r--  source/use_case/asr/include/Wav2LetterPreprocess.hpp  |   96
-rw-r--r--  source/use_case/asr/src/MainLoop.cc                   |   85
-rw-r--r--  source/use_case/asr/src/UseCaseHandler.cc             |  166
-rw-r--r--  source/use_case/asr/src/Wav2LetterMfcc.cc             |    1
-rw-r--r--  source/use_case/asr/src/Wav2LetterModel.cc            |    1
-rw-r--r--  source/use_case/asr/src/Wav2LetterPostprocess.cc      |  153
-rw-r--r--  source/use_case/asr/src/Wav2LetterPreprocess.cc       |  106
-rw-r--r--  source/use_case/kws/include/KwsProcessing.hpp         |   23
-rw-r--r--  source/use_case/kws/src/KwsProcessing.cc              |    5
-rw-r--r--  source/use_case/kws/src/UseCaseHandler.cc             |    4
-rw-r--r--  tests/common/PlatformMathTests.cpp                    |   33
-rw-r--r--  tests/use_case/asr/AsrFeaturesTests.cc                |   52
-rw-r--r--  tests/use_case/asr/Wav2LetterPostprocessingTest.cc    |  124
-rw-r--r--  tests/use_case/asr/Wav2LetterPreprocessingTest.cc     |  120
17 files changed, 495 insertions(+), 594 deletions(-)
diff --git a/source/use_case/asr/include/AsrResult.hpp b/source/use_case/asr/include/AsrResult.hpp
index b12ed7d..ed826d0 100644
--- a/source/use_case/asr/include/AsrResult.hpp
+++ b/source/use_case/asr/include/AsrResult.hpp
@@ -25,7 +25,7 @@ namespace arm {
namespace app {
namespace asr {
- using ResultVec = std::vector < arm::app::ClassificationResult >;
+ using ResultVec = std::vector<arm::app::ClassificationResult>;
/* Structure for holding ASR result. */
class AsrResult {
diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp
index 55395b9..895df2b 100644
--- a/source/use_case/asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/asr/include/Wav2LetterModel.hpp
@@ -36,6 +36,9 @@ namespace app {
static constexpr uint32_t ms_outputRowsIdx = 2;
static constexpr uint32_t ms_outputColsIdx = 3;
+ static constexpr uint32_t ms_blankTokenIdx = 28;
+ static constexpr uint32_t ms_numMfccFeatures = 13;
+
protected:
/** @brief Gets the reference to op resolver interface class. */
const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
index 29eb548..45defa5 100644
--- a/source/use_case/asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,93 +17,90 @@
#ifndef ASR_WAV2LETTER_POSTPROCESS_HPP
#define ASR_WAV2LETTER_POSTPROCESS_HPP
-#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */
+#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */
+#include "BaseProcessing.hpp"
+#include "AsrClassifier.hpp"
+#include "AsrResult.hpp"
#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/**
* @brief Helper class to manage tensor post-processing for "wav2letter"
* output.
*/
- class Postprocess {
+ class ASRPostProcess : public BasePostProcess {
public:
+ bool m_lastIteration = false; /* Flag to be set when processing the last set of data for a clip. */
+
/**
- * @brief Constructor
- * @param[in] contextLen Left and right context length for
- * output tensor.
- * @param[in] innerLen This is the length of the section
- * between left and right context.
- * @param[in] blankTokenIdx Blank token index.
+ * @brief Constructor
+ * @param[in] outputTensor Pointer to the output Tensor.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in,out] result Vector of classification results to store decoded outputs.
+ * @param[in] outputContextLen Left/right context length for output tensor.
+ * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes.
+ * @param[in] reductionAxis The axis that the logits of each time step are on.
**/
- Postprocess(uint32_t contextLen,
- uint32_t innerLen,
- uint32_t blankTokenIdx);
-
- Postprocess() = delete;
- ~Postprocess() = default;
+ ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+ const std::vector<std::string>& labels, asr::ResultVec& result,
+ uint32_t outputContextLen,
+ uint32_t blankTokenIdx, uint32_t reductionAxis);
/**
- * @brief Erases the required part of the tensor based
- * on context lengths set up during initialisation.
- * @param[in] tensor Pointer to the tensor.
- * @param[in] axisIdx Index of the axis on which erase is
- * performed.
- * @param[in] lastIteration Flag to signal this is the
- * last iteration in which case
- * the right context is preserved.
- * @return true if successful, false otherwise.
- */
- bool Invoke(TfLiteTensor* tensor,
- uint32_t axisIdx,
- bool lastIteration = false);
+ * @brief Should perform post-processing of the result of inference then
+ * populate ASR result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+
+ /** @brief Gets the output inner length for post-processing. */
+ static uint32_t GetOutputInnerLen(const TfLiteTensor* outputTensor, uint32_t outputCtxLen);
+
+ /** @brief Gets the output context length (left/right) for post-processing. */
+ static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen);
+
+ /** @brief Gets the number of feature vectors to be computed. */
+ static uint32_t GetNumFeatureVectors(const Model& model);
private:
- uint32_t m_contextLen; /* lengths of left and right contexts. */
- uint32_t m_innerLen; /* Length of inner context. */
- uint32_t m_totalLen; /* Total length of the required axis. */
- uint32_t m_countIterations; /* Current number of iterations. */
- uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ AsrClassifier& m_classifier; /* ASR Classifier object. */
+ TfLiteTensor* m_outputTensor; /* Model output tensor. */
+ const std::vector<std::string>& m_labels; /* ASR Labels. */
+ asr::ResultVec& m_results; /* Results vector for a single inference. */
+ uint32_t m_outputContextLen; /* Lengths of left/right contexts for output. */
+ uint32_t m_outputInnerLen; /* Length of output inner context. */
+ uint32_t m_totalLen; /* Total length of the required axis. */
+ uint32_t m_countIterations; /* Current number of iterations. */
+ uint32_t m_blankTokenIdx; /* Index of the labels blank token. */
+ uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */
+
/**
- * @brief Checks if the tensor and axis index are valid
- * inputs to the object - based on how it has been
- * initialised.
- * @return true if valid, false otherwise.
+ * @brief Checks if the tensor and axis index are valid
+ * inputs to the object - based on how it has been initialised.
+ * @return true if valid, false otherwise.
*/
bool IsInputValid(TfLiteTensor* tensor,
- const uint32_t axisIdx) const;
+ uint32_t axisIdx) const;
/**
- * @brief Gets the tensor data element size in bytes based
- * on the tensor type.
- * @return Size in bytes, 0 if not supported.
+ * @brief Gets the tensor data element size in bytes based
+ * on the tensor type.
+ * @return Size in bytes, 0 if not supported.
*/
static uint32_t GetTensorElementSize(TfLiteTensor* tensor);
/**
- * @brief Erases sections from the data assuming row-wise
- * arrangement along the context axis.
- * @return true if successful, false otherwise.
+ * @brief Erases sections from the data assuming row-wise
+ * arrangement along the context axis.
+ * @return true if successful, false otherwise.
*/
bool EraseSectionsRowWise(uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration);
-
- /**
- * @brief Erases sections from the data assuming col-wise
- * arrangement along the context axis.
- * @return true if successful, false otherwise.
- */
- static bool EraseSectionsColWise(const uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration);
+ uint32_t strideSzBytes,
+ bool lastIteration);
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
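For orientation, the new class is driven roughly as follows; this is a sketch based on the constructor and DoPostProcess() declared above, with the surrounding variables (model, classifier, labels, inputCtxLen, isLastAudioWindow) as illustrative placeholders. See UseCaseHandler.cc further down for the real wiring.

    /* Derive the output-side context length from the model and input context. */
    std::vector<arm::app::ClassificationResult> singleInfResult;
    const uint32_t outputCtxLen =
        arm::app::ASRPostProcess::GetOutputContextLen(model, inputCtxLen);

    arm::app::ASRPostProcess postProcess(classifier, model.GetOutputTensor(0),
                                         labels, singleInfResult, outputCtxLen,
                                         arm::app::Wav2LetterModel::ms_blankTokenIdx,
                                         arm::app::Wav2LetterModel::ms_outputRowsIdx);

    postProcess.m_lastIteration = isLastAudioWindow; /* Preserve right context on the final window. */
    if (!postProcess.DoPostProcess()) {
        /* Handle the error. */
    }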
diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
index 13d1589..8c12b3d 100644
--- a/source/use_case/asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,49 +21,44 @@
#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "DataStructures.hpp"
+#include "BaseProcessing.hpp"
#include "log_macros.h"
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
/* Class to facilitate pre-processing calculation for Wav2Letter model
* for ASR. */
- using AudioWindow = SlidingWindow <const int16_t>;
+ using AudioWindow = audio::SlidingWindow<const int16_t>;
- class Preprocess {
+ class ASRPreProcess : public BasePreProcess {
public:
/**
* @brief Constructor.
- * @param[in] numMfccFeatures Number of MFCC features per window.
- * @param[in] windowLen Number of elements in a window.
- * @param[in] windowStride Stride (in number of elements) for
- * moving the window.
- * @param[in] numMfccVectors Number of MFCC vectors per window.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] numMfccFeatures Number of MFCC features per window.
+ * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
+ * for an inference.
+ * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window.
+ * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window.
*/
- Preprocess(
- uint32_t numMfccFeatures,
- uint32_t windowLen,
- uint32_t windowStride,
- uint32_t numMfccVectors);
- Preprocess() = delete;
- ~Preprocess() = default;
+ ASRPreProcess(TfLiteTensor* inputTensor,
+ uint32_t numMfccFeatures,
+ uint32_t numFeatureFrames,
+ uint32_t mfccWindowLen,
+ uint32_t mfccWindowStride);
/**
* @brief Calculates the features required from audio data. This
* includes MFCC, first and second order deltas,
* normalisation and finally, quantisation. The tensor is
- * populated with feature from a given window placed along
+ * populated with features from a given window placed along
* in a single row.
* @param[in] audioData Pointer to the first element of audio data.
* @param[in] audioDataLen Number of elements in the audio data.
- * @param[in] tensor Tensor to be populated.
* @return true if successful, false in case of error.
*/
- bool Invoke(const int16_t * audioData,
- uint32_t audioDataLen,
- TfLiteTensor * tensor);
+ bool DoPreProcess(const void* audioData, size_t audioDataLen) override;
protected:
/**
@@ -80,32 +75,16 @@ namespace asr {
Array2d<float>& delta2);
/**
- * @brief Given a 2D vector of floats, computes the mean.
- * @param[in] vec Vctor of vector of floats.
- * @return Mean value.
- */
- static float GetMean(Array2d<float>& vec);
-
- /**
- * @brief Given a 2D vector of floats, computes the stddev.
- * @param[in] vec Vector of vector of floats.
- * @param[in] mean Mean value of the vector passed in.
- * @return stddev value.
- */
- static float GetStdDev(Array2d<float>& vec,
- const float mean);
-
- /**
- * @brief Given a 2D vector of floats, normalises it using
- * the mean and the stddev.
+ * @brief Given a 2D vector of floats, rescale it to have mean of 0 and
+ * standard deviation of 1.
* @param[in,out] vec Vector of vector of floats.
*/
- static void NormaliseVec(Array2d<float>& vec);
+ static void StandardizeVecF32(Array2d<float>& vec);
/**
- * @brief Normalises the MFCC and delta buffers.
+ * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1.
*/
- void Normalise();
+ void Standardize();
/**
* @brief Given the quantisation and data type limits, computes
@@ -139,7 +118,7 @@ namespace asr {
*/
template <typename T>
bool Quantise(
- T * outputBuf,
+ T* outputBuf,
const uint32_t outputBufSz,
const float quantScale,
const int quantOffset)
@@ -160,15 +139,15 @@ namespace asr {
const float maxVal = std::numeric_limits<T>::max();
/* Need to transpose while copying and concatenating the tensor. */
- for (uint32_t j = 0; j < this->m_numFeatVectors; ++j) {
+ for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
- *outputBufMfcc++ = static_cast<T>(Preprocess::GetQuantElem(
+ *outputBufMfcc++ = static_cast<T>(ASRPreProcess::GetQuantElem(
this->m_mfccBuf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD1++ = static_cast<T>(Preprocess::GetQuantElem(
+ *outputBufD1++ = static_cast<T>(ASRPreProcess::GetQuantElem(
this->m_delta1Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD2++ = static_cast<T>(Preprocess::GetQuantElem(
+ *outputBufD2++ = static_cast<T>(ASRPreProcess::GetQuantElem(
this->m_delta2Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
}
@@ -181,23 +160,22 @@ namespace asr {
}
private:
- Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */
+ TfLiteTensor* m_inputTensor; /* Model input tensor. */
/* Actual buffers to be populated. */
- Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
- Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
- Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
+ Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */
+ Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */
+ Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */
- uint32_t m_windowLen; /* Window length for MFCC. */
- uint32_t m_windowStride; /* Window stride len for MFCC. */
- uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
- uint32_t m_numFeatVectors; /* Number of m_numMfccFeats. */
- AudioWindow m_window; /* Sliding window. */
+ uint32_t m_mfccWindowLen; /* Window length for MFCC. */
+ uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */
+ uint32_t m_numMfccFeats; /* Number of MFCC features per window. */
+ uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */
+ AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */
};
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */
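The pre-processing counterpart follows the same pattern. A sketch of typical use, assuming an initialised model and MFCC frame parameters taken from the application context (this mirrors the UseCaseHandler.cc hunk below):

    TfLiteIntArray* inputShape = model.GetInputShape(0);
    arm::app::ASRPreProcess preProcess(
        model.GetInputTensor(0),
        arm::app::Wav2LetterModel::ms_numMfccFeatures,
        inputShape->data[arm::app::Wav2LetterModel::ms_inputRowsIdx], /* Feature frames per inference. */
        mfccFrameLen, mfccFrameStride);

    /* Computes MFCCs plus first/second order deltas, standardizes them, and
     * quantises the result straight into the model's input tensor. */
    preProcess.DoPreProcess(audioWindow, audioWindowLen);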
diff --git a/source/use_case/asr/src/MainLoop.cc b/source/use_case/asr/src/MainLoop.cc
index 51b0b18..a1a9540 100644
--- a/source/use_case/asr/src/MainLoop.cc
+++ b/source/use_case/asr/src/MainLoop.cc
@@ -14,15 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "hal.h" /* Brings in platform definitions. */
#include "Labels.hpp" /* For label strings. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
#include "Wav2LetterModel.hpp" /* Model class for running inference. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
#include "AsrClassifier.hpp" /* Classifier. */
#include "InputFiles.hpp" /* Generated audio clip header. */
-#include "Wav2LetterPreprocess.hpp" /* Pre-processing class. */
-#include "Wav2LetterPostprocess.hpp" /* Post-processing class. */
#include "log_macros.h"
enum opcodes
@@ -48,23 +45,9 @@ static void DisplayMenu()
fflush(stdout);
}
-/** @brief Verify input and output tensor are of certain min dimensions. */
+/** @brief Verify input and output tensors are of certain min dimensions. */
static bool VerifyTensorDimensions(const arm::app::Model& model);
-/** @brief Gets the number of MFCC features for a single window. */
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model);
-
-/** @brief Gets the number of MFCC feature vectors to be computed. */
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model);
-
-/** @brief Gets the output context length (left and right) for post-processing. */
-static uint32_t GetOutputContextLen(const arm::app::Model& model,
- uint32_t inputCtxLen);
-
-/** @brief Gets the output inner length for post-processing. */
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- uint32_t outputCtxLen);
-
void main_loop()
{
arm::app::Wav2LetterModel model; /* Model wrapper object. */
@@ -78,21 +61,6 @@ void main_loop()
return;
}
- /* Initialise pre-processing. */
- arm::app::audio::asr::Preprocess prep(
- GetNumMfccFeatures(model),
- g_FrameLength,
- g_FrameStride,
- GetNumMfccFeatureVectors(model));
-
- /* Initialise post-processing. */
- const uint32_t outputCtxLen = GetOutputContextLen(model, g_ctxLen);
- const uint32_t blankTokenIdx = 28;
- arm::app::audio::asr::Postprocess postp(
- outputCtxLen,
- GetOutputInnerLen(model, outputCtxLen),
- blankTokenIdx);
-
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
std::vector <std::string> labels;
@@ -109,8 +77,6 @@ void main_loop()
caseContext.Set<uint32_t>("ctxLen", g_ctxLen); /* Left and right context length (MFCC feat vectors). */
caseContext.Set<const std::vector <std::string>&>("labels", labels);
caseContext.Set<arm::app::AsrClassifier&>("classifier", classifier);
- caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep);
- caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp);
bool executionSuccessful = true;
constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false;
@@ -184,52 +150,3 @@ static bool VerifyTensorDimensions(const arm::app::Model& model)
return true;
}
-
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model)
-{
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
- if (0 != inputCols % 3) {
- printf_err("Number of input columns is not a multiple of 3\n");
- }
- return std::max(inputCols/3, 0);
-}
-
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model)
-{
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
- return std::max(inputRows, 0);
-}
-
-static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen)
-{
- const uint32_t inputRows = GetNumMfccFeatureVectors(model);
- const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-
- /* Check to make sure that the input tensor supports the above
- * context and inner lengths. */
- if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
- printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
- inputCtxLen);
- return 0;
- }
-
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
-
- const float tensorColRatio = static_cast<float>(inputRows)/
- static_cast<float>(outputRows);
-
- return std::round(static_cast<float>(inputCtxLen)/tensorColRatio);
-}
-
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- const uint32_t outputCtxLen)
-{
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
- return (outputRows - (2 * outputCtxLen));
-}
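None of the helper logic removed above is lost: GetNumMfccFeatures is superseded by the new Wav2LetterModel::ms_numMfccFeatures constant, while the feature-vector and context/inner-length helpers reappear as static members of ASRPostProcess (see Wav2LetterPostprocess.cc below). Callers now derive the lengths through the class, e.g.:

    const uint32_t outputCtxLen =
        arm::app::ASRPostProcess::GetOutputContextLen(model, inputCtxLen);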
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 420f725..7fe959b 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -20,7 +20,6 @@
#include "AsrClassifier.hpp"
#include "Wav2LetterModel.hpp"
#include "hal.h"
-#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
@@ -34,68 +33,63 @@ namespace arm {
namespace app {
/**
- * @brief Presents inference results using the data presentation
- * object.
- * @param[in] results Vector of classification results to be displayed.
+ * @brief Presents ASR inference results.
+ * @param[in] results Vector of ASR classification results to be displayed.
* @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results);
+ static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
- /* Audio inference classification handler. */
+ /* ASR inference handler. */
bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
{
- constexpr uint32_t dataPsnTxtInfStartX = 20;
- constexpr uint32_t dataPsnTxtInfStartY = 40;
-
- hal_lcd_clear(COLOR_BLACK);
-
+ auto& model = ctx.Get<Model&>("model");
auto& profiler = ctx.Get<Profiler&>("profiler");
-
+ auto mfccFrameLen = ctx.Get<uint32_t>("frameLength");
+ auto mfccFrameStride = ctx.Get<uint32_t>("frameStride");
+ auto scoreThreshold = ctx.Get<float>("scoreThreshold");
+ auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
/* If the request has a valid size, set the audio index. */
if (clipIndex < NUMBER_OF_FILES) {
if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
return false;
}
}
+ auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
+ constexpr uint32_t dataPsnTxtInfStartX = 20;
+ constexpr uint32_t dataPsnTxtInfStartY = 40;
- /* Get model reference. */
- auto& model = ctx.Get<Model&>("model");
if (!model.IsInited()) {
printf_err("Model is not initialised! Terminating processing.\n");
return false;
}
- /* Get score threshold to be applied for the classifier (post-inference). */
- auto scoreThreshold = ctx.Get<float>("scoreThreshold");
-
- /* Get tensors. Dimensions of the tensor should have been verified by
+ /* Get input shape. Dimensions of the tensor should have been verified by
* the callee. */
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
+ TfLiteIntArray* inputShape = model.GetInputShape(0);
- /* Populate MFCC related parameters. */
- auto mfccParamsWinLen = ctx.Get<uint32_t>("frameLength");
- auto mfccParamsWinStride = ctx.Get<uint32_t>("frameStride");
-
- /* Populate ASR inference context and inner lengths for input. */
- auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
- const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+ const uint32_t inputRowsSize = inputShape->data[Wav2LetterModel::ms_inputRowsIdx];
+ const uint32_t inputInnerLen = inputRowsSize - (2 * inputCtxLen);
/* Audio data stride corresponds to inputInnerLen feature vectors. */
- const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen);
- const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride;
- const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
+ const uint32_t audioDataWindowLen = (inputRowsSize - 1) * mfccFrameStride + (mfccFrameLen);
+ const uint32_t audioDataWindowStride = inputInnerLen * mfccFrameStride;
+
+ /* NOTE: This is only used for time stamp calculation. */
+ const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
- /* Get pre/post-processing objects. */
- auto& prep = ctx.Get<audio::asr::Preprocess&>("preprocess");
- auto& postp = ctx.Get<audio::asr::Postprocess&>("postprocess");
+ /* Set up pre and post-processing objects. */
+ ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures,
+ inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride);
- /* Set default reduction axis for post-processing. */
- const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
+ std::vector<ClassificationResult> singleInfResult;
+ const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
+ ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"),
+ model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"),
+ singleInfResult, outputCtxLen,
+ Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
+ );
- /* Audio clip start index. */
- auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
+ UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model);
/* Loop to process audio clips. */
do {
@@ -109,44 +103,41 @@ namespace app {
const uint32_t audioArrSize = get_audio_array_size(currentIndex);
if (!audioArr) {
- printf_err("Invalid audio array pointer\n");
+ printf_err("Invalid audio array pointer.\n");
return false;
}
- /* Audio clip must have enough samples to produce 1 MFCC feature. */
- if (audioArrSize < mfccParamsWinLen) {
+ /* Audio clip needs enough samples to produce at least 1 MFCC feature. */
+ if (audioArrSize < mfccFrameLen) {
printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
- mfccParamsWinLen);
+ mfccFrameLen);
return false;
}
- /* Initialise an audio slider. */
+ /* Creating a sliding window through the whole audio clip. */
auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
- audioArr,
- audioArrSize,
- audioParamsWinLen,
- audioParamsWinStride);
+ audioArr, audioArrSize,
+ audioDataWindowLen, audioDataWindowStride);
- /* Declare a container for results. */
- std::vector<arm::app::asr::AsrResult> results;
+ /* Declare a container for final results. */
+ std::vector<asr::AsrResult> finalResults;
/* Display message on the LCD - inference running. */
std::string str_inf{"Running inference... "};
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
get_filename(currentIndex));
- size_t inferenceWindowLen = audioParamsWinLen;
+ size_t inferenceWindowLen = audioDataWindowLen;
/* Start sliding through audio clip. */
while (audioDataSlider.HasNext()) {
- /* If not enough audio see how much can be sent for processing. */
+ /* If not enough audio, see how much can be sent for processing. */
size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
- if (nextStartIndex + audioParamsWinLen > audioArrSize) {
+ if (nextStartIndex + audioDataWindowLen > audioArrSize) {
inferenceWindowLen = audioArrSize - nextStartIndex;
}
@@ -155,46 +146,40 @@ namespace app {
info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
- /* Calculate MFCCs, deltas and populate the input tensor. */
- prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor);
+ /* Run the pre-processing, inference and post-processing. */
+ runner.PreProcess(inferenceWindow, inferenceWindowLen);
- /* Run inference over this audio clip sliding window. */
- if (!RunInference(model, profiler)) {
+ profiler.StartProfiling("Inference");
+ if (!runner.RunInference()) {
return false;
}
+ profiler.StopProfiling();
- /* Post-process. */
- postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext());
-
- /* Get results. */
- std::vector<ClassificationResult> classificationResult;
- auto& classifier = ctx.Get<AsrClassifier&>("classifier");
- classifier.GetClassificationResults(
- outputTensor, classificationResult,
- ctx.Get<std::vector<std::string>&>("labels"), 1);
+ postProcess.m_lastIteration = !audioDataSlider.HasNext();
+ if (!runner.PostProcess()) {
+ return false;
+ }
- results.emplace_back(asr::AsrResult(classificationResult,
- (audioDataSlider.Index() *
- audioParamsSecondsPerSample *
- audioParamsWinStride),
- audioDataSlider.Index(), scoreThreshold));
+ /* Add results from this window to our final results vector. */
+ finalResults.emplace_back(asr::AsrResult(singleInfResult,
+ (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride),
+ audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(outputTensor,
- outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+ armDumpTensor(outputTensor,
+ outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */
-
- }
+ } /* while (audioDataSlider.HasNext()) */
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
- hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
- ctx.Set<std::vector<arm::app::asr::AsrResult>>("results", results);
+ ctx.Set<std::vector<asr::AsrResult>>("results", finalResults);
- if (!PresentInferenceResult(results)) {
+ if (!PresentInferenceResult(finalResults)) {
return false;
}
@@ -202,13 +187,13 @@ namespace app {
IncrementAppCtxIfmIdx(ctx,"clipIndex");
- } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
+ } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx);
return true;
}
- static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results)
+ static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results)
{
constexpr uint32_t dataPsnTxtStartX1 = 20;
constexpr uint32_t dataPsnTxtStartY1 = 60;
@@ -219,15 +204,15 @@ namespace app {
info("Final results:\n");
info("Total number of inferences: %zu\n", results.size());
/* Results from multiple inferences should be combined before processing. */
- std::vector<arm::app::ClassificationResult> combinedResults;
- for (auto& result : results) {
+ std::vector<ClassificationResult> combinedResults;
+ for (const auto& result : results) {
combinedResults.insert(combinedResults.end(),
result.m_resultVec.begin(),
result.m_resultVec.end());
}
/* Get each inference result string using the decoder. */
- for (const auto & result : results) {
+ for (const auto& result : results) {
std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n",
@@ -238,10 +223,9 @@ namespace app {
/* Get the decoded result for the combined result. */
std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
- hal_lcd_display_text(
- finalResultStr.c_str(), finalResultStr.size(),
- dataPsnTxtStartX1, dataPsnTxtStartY1,
- allow_multiple_lines);
+ hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
+ dataPsnTxtStartX1, dataPsnTxtStartY1,
+ allow_multiple_lines);
info("Complete recognition: %s\n", finalResultStr.c_str());
return true;
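Taken together, the handler's per-window flow now follows the generic runner pattern; schematically (error handling elided, names as in the hunks above):

    while (audioDataSlider.HasNext()) {
        const int16_t* inferenceWindow = audioDataSlider.Next();

        runner.PreProcess(inferenceWindow, inferenceWindowLen); /* MFCCs + deltas -> input tensor. */

        profiler.StartProfiling("Inference");
        runner.RunInference();
        profiler.StopProfiling();

        postProcess.m_lastIteration = !audioDataSlider.HasNext(); /* Preserve right context at the end. */
        runner.PostProcess();                                     /* Erase contexts, decode labels. */
    }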
diff --git a/source/use_case/asr/src/Wav2LetterMfcc.cc b/source/use_case/asr/src/Wav2LetterMfcc.cc
index 1bcaa66..bb29b0f 100644
--- a/source/use_case/asr/src/Wav2LetterMfcc.cc
+++ b/source/use_case/asr/src/Wav2LetterMfcc.cc
@@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "Wav2LetterMfcc.hpp"
+
#include "PlatformMath.hpp"
#include "log_macros.h"
diff --git a/source/use_case/asr/src/Wav2LetterModel.cc b/source/use_case/asr/src/Wav2LetterModel.cc
index 766bce9..8b38f4f 100644
--- a/source/use_case/asr/src/Wav2LetterModel.cc
+++ b/source/use_case/asr/src/Wav2LetterModel.cc
@@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "Wav2LetterModel.hpp"
+
#include "log_macros.h"
diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc
index 0392061..e3e1999 100644
--- a/source/use_case/asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,67 +15,71 @@
* limitations under the License.
*/
#include "Wav2LetterPostprocess.hpp"
+
#include "Wav2LetterModel.hpp"
#include "log_macros.h"
+#include <cmath>
+
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
-
- Postprocess::Postprocess(const uint32_t contextLen,
- const uint32_t innerLen,
- const uint32_t blankTokenIdx)
- : m_contextLen(contextLen),
- m_innerLen(innerLen),
- m_totalLen(2 * this->m_contextLen + this->m_innerLen),
+
+ ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+ const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
+ const uint32_t outputContextLen,
+ const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
+ ):
+ m_classifier(classifier),
+ m_outputTensor(outputTensor),
+ m_labels{labels},
+ m_results(results),
+ m_outputContextLen(outputContextLen),
m_countIterations(0),
- m_blankTokenIdx(blankTokenIdx)
- {}
+ m_blankTokenIdx(blankTokenIdx),
+ m_reductionAxisIdx(reductionAxisIdx)
+ {
+ this->m_outputInnerLen = ASRPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
+ this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
+ }
- bool Postprocess::Invoke(TfLiteTensor* tensor,
- const uint32_t axisIdx,
- const bool lastIteration)
+ bool ASRPostProcess::DoPostProcess()
{
/* Basic checks. */
- if (!this->IsInputValid(tensor, axisIdx)) {
+ if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
return false;
}
/* Irrespective of tensor type, we use unsigned "byte" */
- uint8_t* ptrData = tflite::GetTensorData<uint8_t>(tensor);
- const uint32_t elemSz = this->GetTensorElementSize(tensor);
+ auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
+ const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor);
/* Other sanity checks. */
if (0 == elemSz) {
printf_err("Tensor type not supported for post processing\n");
return false;
- } else if (elemSz * this->m_totalLen > tensor->bytes) {
+ } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) {
printf_err("Insufficient number of tensor bytes\n");
return false;
}
/* Which axis do we need to process? */
- switch (axisIdx) {
- case arm::app::Wav2LetterModel::ms_outputRowsIdx:
- return this->EraseSectionsRowWise(ptrData,
- elemSz *
- tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx],
- lastIteration);
- case arm::app::Wav2LetterModel::ms_outputColsIdx:
- return this->EraseSectionsColWise(ptrData,
- elemSz *
- tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx],
- lastIteration);
+ switch (this->m_reductionAxisIdx) {
+ case Wav2LetterModel::ms_outputRowsIdx:
+ this->EraseSectionsRowWise(
+ ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx],
+ this->m_lastIteration);
+ break;
default:
- printf_err("Unsupported axis index: %" PRIu32 "\n", axisIdx);
+ printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx);
+ return false;
}
+ this->m_classifier.GetClassificationResults(this->m_outputTensor,
+ this->m_results, this->m_labels, 1);
- return false;
+ return true;
}
- bool Postprocess::IsInputValid(TfLiteTensor* tensor,
- const uint32_t axisIdx) const
+ bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
{
if (nullptr == tensor) {
return false;
@@ -89,15 +93,15 @@ namespace asr {
if (static_cast<int>(this->m_totalLen) !=
tensor->dims->data[axisIdx]) {
- printf_err("Unexpected tensor dimension for axis %d, \n",
- tensor->dims->data[axisIdx]);
+ printf_err("Unexpected tensor dimension for axis %d, got %d, \n",
+ axisIdx, tensor->dims->data[axisIdx]);
return false;
}
return true;
}
- uint32_t Postprocess::GetTensorElementSize(TfLiteTensor* tensor)
+ uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
{
switch(tensor->type) {
case kTfLiteUInt8:
@@ -116,30 +120,30 @@ namespace asr {
return 0;
}
- bool Postprocess::EraseSectionsRowWise(
- uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration)
+ bool ASRPostProcess::EraseSectionsRowWise(
+ uint8_t* ptrData,
+ const uint32_t strideSzBytes,
+ const bool lastIteration)
{
/* In this case, the "zero-ing" is quite simple as the region
* to be zeroed sits in contiguous memory (row-major). */
- const uint32_t eraseLen = strideSzBytes * this->m_contextLen;
+ const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen;
/* Erase left context? */
if (this->m_countIterations > 0) {
/* Set output of each classification window to the blank token. */
std::memset(ptrData, 0, eraseLen);
- for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+ for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
}
}
/* Erase right context? */
if (false == lastIteration) {
- uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->m_contextLen + this->m_innerLen));
+ uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen));
/* Set output of each classification window to the blank token. */
std::memset(rightCtxPtr, 0, eraseLen);
- for (size_t windowIdx = 0; windowIdx < this->m_contextLen; windowIdx++) {
+ for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) {
rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1;
}
}
@@ -153,19 +157,56 @@ namespace asr {
return true;
}
- bool Postprocess::EraseSectionsColWise(
- const uint8_t* ptrData,
- const uint32_t strideSzBytes,
- const bool lastIteration)
+ uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model)
+ {
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
+ if (inputRows == 0) {
+ printf_err("Error getting number of input rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_inputRowsIdx);
+ }
+ return inputRows;
+ }
+
+ uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
{
- /* Not implemented. */
- UNUSED(ptrData);
- UNUSED(strideSzBytes);
- UNUSED(lastIteration);
- return false;
+ const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
+ if (outputRows == 0) {
+ printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_outputRowsIdx);
+ }
+ int innerLen = (outputRows - (2 * outputCtxLen));
+
+ return std::max(innerLen, 0);
+ }
+
+ uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+ {
+ const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model);
+ const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
+ constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
+
+ /* Check to make sure that the input tensor supports the above
+ * context and inner lengths. */
+ if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
+ printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
+ inputCtxLen);
+ return 0;
+ }
+
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+ const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
+ if (outputRows == 0) {
+ printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
+ Wav2LetterModel::ms_outputRowsIdx);
+ return 0;
+ }
+
+ const float inOutRowRatio = static_cast<float>(inputRows) /
+ static_cast<float>(outputRows);
+
+ return std::round(static_cast<float>(inputCtxLen) / inOutRowRatio);
}
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */ \ No newline at end of file
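To make GetOutputContextLen concrete, a worked example with illustrative numbers (they are not stated in this diff): with inputRows = 296 feature frames and inputCtxLen = 98, the inner length is 296 - 2*98 = 100; if the output tensor then has 148 rows, inOutRowRatio = 296/148 = 2, so outputCtxLen = round(98/2) = 49 and GetOutputInnerLen yields 148 - 2*49 = 50. The post-processor's m_totalLen check (2*49 + 50 = 148) then matches the output axis exactly.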
diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc
index e5ac3ca..590d08a 100644
--- a/source/use_case/asr/src/Wav2LetterPreprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,37 +24,31 @@
namespace arm {
namespace app {
-namespace audio {
-namespace asr {
-
- Preprocess::Preprocess(
- const uint32_t numMfccFeatures,
- const uint32_t windowLen,
- const uint32_t windowStride,
- const uint32_t numMfccVectors):
- m_mfcc(numMfccFeatures, windowLen),
- m_mfccBuf(numMfccFeatures, numMfccVectors),
- m_delta1Buf(numMfccFeatures, numMfccVectors),
- m_delta2Buf(numMfccFeatures, numMfccVectors),
- m_windowLen(windowLen),
- m_windowStride(windowStride),
+
+ ASRPreProcess::ASRPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
+ const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
+ const uint32_t mfccWindowStride
+ ):
+ m_mfcc(numMfccFeatures, mfccWindowLen),
+ m_inputTensor(inputTensor),
+ m_mfccBuf(numMfccFeatures, numFeatureFrames),
+ m_delta1Buf(numMfccFeatures, numFeatureFrames),
+ m_delta2Buf(numMfccFeatures, numFeatureFrames),
+ m_mfccWindowLen(mfccWindowLen),
+ m_mfccWindowStride(mfccWindowStride),
m_numMfccFeats(numMfccFeatures),
- m_numFeatVectors(numMfccVectors),
- m_window()
+ m_numFeatureFrames(numFeatureFrames)
{
- if (numMfccFeatures > 0 && windowLen > 0) {
+ if (numMfccFeatures > 0 && mfccWindowLen > 0) {
this->m_mfcc.Init();
}
}
- bool Preprocess::Invoke(
- const int16_t* audioData,
- const uint32_t audioDataLen,
- TfLiteTensor* tensor)
+ bool ASRPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
{
- this->m_window = SlidingWindow<const int16_t>(
- audioData, audioDataLen,
- this->m_windowLen, this->m_windowStride);
+ this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(
+ static_cast<const int16_t*>(audioData), audioDataLen,
+ this->m_mfccWindowLen, this->m_mfccWindowStride);
uint32_t mfccBufIdx = 0;
@@ -62,12 +56,12 @@ namespace asr {
std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f);
std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f);
- /* While we can slide over the window. */
- while (this->m_window.HasNext()) {
- const int16_t* mfccWindow = this->m_window.Next();
+ /* While we can slide over the audio. */
+ while (this->m_mfccSlidingWindow.HasNext()) {
+ const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
auto mfccAudioData = std::vector<int16_t>(
mfccWindow,
- mfccWindow + this->m_windowLen);
+ mfccWindow + this->m_mfccWindowLen);
auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData);
for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) {
this->m_mfccBuf(i, mfccBufIdx) = mfcc[i];
@@ -76,11 +70,11 @@ namespace asr {
}
/* Pad MFCC if needed by adding MFCC for zeros. */
- if (mfccBufIdx != this->m_numFeatVectors) {
- std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0);
+ if (mfccBufIdx != this->m_numFeatureFrames) {
+ std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0);
std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow);
- while (mfccBufIdx != this->m_numFeatVectors) {
+ while (mfccBufIdx != this->m_numFeatureFrames) {
memcpy(&this->m_mfccBuf(0, mfccBufIdx),
mfccZeros.data(), sizeof(float) * m_numMfccFeats);
++mfccBufIdx;
@@ -88,39 +82,37 @@ namespace asr {
}
/* Compute first and second order deltas from MFCCs. */
- Preprocess::ComputeDeltas(this->m_mfccBuf,
- this->m_delta1Buf,
- this->m_delta2Buf);
+ ASRPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
- /* Normalise. */
- this->Normalise();
+ /* Standardize calculated features. */
+ this->Standardize();
/* Quantise. */
- QuantParams quantParams = GetTensorQuantParams(tensor);
+ QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor);
if (0 == quantParams.scale) {
printf_err("Quantisation scale can't be 0\n");
return false;
}
- switch(tensor->type) {
+ switch(this->m_inputTensor->type) {
case kTfLiteUInt8:
return this->Quantise<uint8_t>(
- tflite::GetTensorData<uint8_t>(tensor), tensor->bytes,
+ tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
quantParams.scale, quantParams.offset);
case kTfLiteInt8:
return this->Quantise<int8_t>(
- tflite::GetTensorData<int8_t>(tensor), tensor->bytes,
+ tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes,
quantParams.scale, quantParams.offset);
default:
printf_err("Unsupported tensor type %s\n",
- TfLiteTypeGetName(tensor->type));
+ TfLiteTypeGetName(this->m_inputTensor->type));
}
return false;
}
- bool Preprocess::ComputeDeltas(Array2d<float>& mfcc,
+ bool ASRPreProcess::ComputeDeltas(Array2d<float>& mfcc,
Array2d<float>& delta1,
Array2d<float>& delta2)
{
@@ -175,20 +167,10 @@ namespace asr {
return true;
}
- float Preprocess::GetMean(Array2d<float>& vec)
- {
- return math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
- }
-
- float Preprocess::GetStdDev(Array2d<float>& vec, const float mean)
- {
- return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
- }
-
- void Preprocess::NormaliseVec(Array2d<float>& vec)
+ void ASRPreProcess::StandardizeVecF32(Array2d<float>& vec)
{
- auto mean = Preprocess::GetMean(vec);
- auto stddev = Preprocess::GetStdDev(vec, mean);
+ auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
+ auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
debug("Mean: %f, Stddev: %f\n", mean, stddev);
if (stddev == 0) {
@@ -204,14 +186,14 @@ namespace asr {
}
}
- void Preprocess::Normalise()
+ void ASRPreProcess::Standardize()
{
- Preprocess::NormaliseVec(this->m_mfccBuf);
- Preprocess::NormaliseVec(this->m_delta1Buf);
- Preprocess::NormaliseVec(this->m_delta2Buf);
+ ASRPreProcess::StandardizeVecF32(this->m_mfccBuf);
+ ASRPreProcess::StandardizeVecF32(this->m_delta1Buf);
+ ASRPreProcess::StandardizeVecF32(this->m_delta2Buf);
}
- float Preprocess::GetQuantElem(
+ float ASRPreProcess::GetQuantElem(
const float elem,
const float quantScale,
const int quantOffset,
@@ -222,7 +204,5 @@ namespace asr {
return std::min<float>(std::max<float>(val, minVal), maxVal);
}
-} /* namespace asr */
-} /* namespace audio */
} /* namespace app */
} /* namespace arm */ \ No newline at end of file
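The Normalise -> Standardize rename reflects what the code actually does: zero-mean, unit-variance scaling. A minimal standalone equivalent of StandardizeVecF32, for illustration only (the kit's version operates on Array2d<float> and handles the constant-input case its own way):

    #include "PlatformMath.hpp"
    #include <cstddef>

    void StandardizeF32(float* data, const size_t len)
    {
        const float mean   = arm::app::math::MathUtils::MeanF32(data, len);
        const float stddev = arm::app::math::MathUtils::StdDevF32(data, len, mean);
        if (stddev == 0) {
            return; /* Constant input: nothing sensible to divide by. */
        }
        for (size_t i = 0; i < len; ++i) {
            data[i] = (data[i] - mean) / stddev;
        }
    }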
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp
index abf20ab..ddf38c1 100644
--- a/source/use_case/kws/include/KwsProcessing.hpp
+++ b/source/use_case/kws/include/KwsProcessing.hpp
@@ -38,7 +38,7 @@ namespace app {
public:
/**
* @brief Constructor
- * @param[in] model Pointer to the the KWS Model object.
+ * @param[in] model Pointer to the KWS Model object.
* @param[in] numFeatures How many MFCC features to use.
* @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
* sliding a window through the audio sample.
@@ -107,24 +107,21 @@ namespace app {
std::vector<ClassificationResult>& m_results;
public:
- const float m_scoreThreshold;
/**
- * @brief Constructor
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] model Pointer to the the Image classification Model object.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in] results Vector of classification results to store decoded outputs.
- * @param[in] scoreThreshold Predicted model score must be larger than this value to be accepted.
+ * @brief Constructor
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] model Pointer to the KWS Model object.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in,out] results Vector of classification results to store decoded outputs.
**/
KWSPostProcess(Classifier& classifier, Model* model,
const std::vector<std::string>& labels,
- std::vector<ClassificationResult>& results,
- float scoreThreshold);
+ std::vector<ClassificationResult>& results);
/**
- * @brief Should perform post-processing of the result of inference then populate
- * populate KWS result data for any later use.
- * @return true if successful, false otherwise.
+ * @brief Should perform post-processing of the result of inference then
+ * populate KWS result data for any later use.
+ * @return true if successful, false otherwise.
**/
bool DoPostProcess() override;
};
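With the threshold removed from the post-processor, constructing it loses one argument; a sketch mirroring the hunk in kws/src/UseCaseHandler.cc further down:

    std::vector<arm::app::ClassificationResult> singleInfResult;
    KWSPostProcess postprocess = KWSPostProcess(ctx.Get<KwsClassifier&>("classifier"), &model,
                                                ctx.Get<std::vector<std::string>&>("labels"),
                                                singleInfResult);
    /* The score threshold now only accompanies each KwsResult instead of living in the post-processor. */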
diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc
index b6b230c..14f9fce 100644
--- a/source/use_case/kws/src/KwsProcessing.cc
+++ b/source/use_case/kws/src/KwsProcessing.cc
@@ -197,11 +197,10 @@ namespace app {
KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model,
const std::vector<std::string>& labels,
- std::vector<ClassificationResult>& results, float scoreThreshold)
+ std::vector<ClassificationResult>& results)
:m_kwsClassifier{classifier},
m_labels{labels},
- m_results{results},
- m_scoreThreshold{scoreThreshold}
+ m_results{results}
{
if (!model->IsInited()) {
printf_err("Model is not initialised!.\n");
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index 350d34b..e73a2c3 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -93,7 +93,7 @@ namespace app {
std::vector<ClassificationResult> singleInfResult;
KWSPostProcess postprocess = KWSPostProcess(ctx.Get<KwsClassifier &>("classifier"), &model,
ctx.Get<std::vector<std::string>&>("labels"),
- singleInfResult, scoreThreshold);
+ singleInfResult);
UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
@@ -146,7 +146,7 @@ namespace app {
/* Add results from this window to our final results vector. */
finalResults.emplace_back(kws::KwsResult(singleInfResult,
audioDataSlider.Index() * secondsPerSample * preprocess.m_audioDataStride,
- audioDataSlider.Index(), postprocess.m_scoreThreshold));
+ audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
TfLiteTensor* outputTensor = model.GetOutputTensor(0);
diff --git a/tests/common/PlatformMathTests.cpp b/tests/common/PlatformMathTests.cpp
index ab1153f..c07cbf1 100644
--- a/tests/common/PlatformMathTests.cpp
+++ b/tests/common/PlatformMathTests.cpp
@@ -150,13 +150,28 @@ TEST_CASE("Test SqrtF32")
TEST_CASE("Test MeanF32")
{
- /*Test Constants: */
+ /* Test Constants: */
std::vector<float> input
{0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 1.000};
/* Manually calculated mean of above vector */
float expectedResult = 0.100;
CHECK (expectedResult == Approx(arm::app::math::MathUtils::MeanF32(input.data(), input.size())));
+
+ /* Mean of 0 */
+ std::vector<float> input2{1, 2, -1, -2};
+ float expectedResult2 = 0.0f;
+ CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::MeanF32(input2.data(), input2.size())));
+
+ /* All 0s */
+ std::vector<float> input3 = std::vector<float>(9, 0);
+ float expectedResult3 = 0.0f;
+ CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::MeanF32(input3.data(), input3.size())));
+
+ /* All 1s */
+ std::vector<float> input4 = std::vector<float>(9, 1);
+ float expectedResult4 = 1.0f;
+ CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::MeanF32(input4.data(), input4.size())));
}
TEST_CASE("Test StdDevF32")
@@ -184,6 +199,22 @@ TEST_CASE("Test StdDevF32")
float expectedResult = 0.969589282958136;
CHECK (expectedResult == Approx(output));
+
+ /* All 0s should have 0 std dev. */
+ std::vector<float> input2 = std::vector<float>(4, 0);
+ float expectedResult2 = 0.0f;
+ CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::StdDevF32(input2.data(), input2.size(), 0.0f)));
+
+ /* All 1s should have 0 std dev. */
+ std::vector<float> input3 = std::vector<float>(4, 1);
+ float expectedResult3 = 0.0f;
+ CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::StdDevF32(input3.data(), input3.size(), 1.0f)));
+
+ /* Manually calculated std. dev. value */
+ std::vector<float> input4 {1, 2, 3, 4, 5, 6, 7, 8, 9, 0};
+ float mean2 = (std::accumulate(input4.begin(), input4.end(), 0.0f))/float(input4.size());
+ float expectedResult4 = 2.872281323;
+ CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::StdDevF32(input4.data(), input4.size(), mean2)));
}
TEST_CASE("Test FFT32")
diff --git a/tests/use_case/asr/AsrFeaturesTests.cc b/tests/use_case/asr/AsrFeaturesTests.cc
index 940c25f..6c23598 100644
--- a/tests/use_case/asr/AsrFeaturesTests.cc
+++ b/tests/use_case/asr/AsrFeaturesTests.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,29 +23,19 @@
#include <catch.hpp>
#include <random>
-class TestPreprocess : public arm::app::audio::asr::Preprocess {
+class TestPreprocess : public arm::app::ASRPreProcess {
public:
static bool ComputeDeltas(arm::app::Array2d<float>& mfcc,
arm::app::Array2d<float>& delta1,
arm::app::Array2d<float>& delta2)
{
- return Preprocess::ComputeDeltas(mfcc, delta1, delta2);
- }
-
- static float GetMean(arm::app::Array2d<float>& vec)
- {
- return Preprocess::GetMean(vec);
- }
-
- static float GetStdDev(arm::app::Array2d<float>& vec, const float mean)
- {
- return Preprocess::GetStdDev(vec, mean);
+ return ASRPreProcess::ComputeDeltas(mfcc, delta1, delta2);
}
static void NormaliseVec(arm::app::Array2d<float>& vec)
{
- return Preprocess::NormaliseVec(vec);
+ return ASRPreProcess::StandardizeVecF32(vec);
}
};
@@ -126,40 +116,6 @@ TEST_CASE("Floating point asr features calculation", "[ASR]")
}
- SECTION("Mean")
- {
- std::vector<std::vector<float>> mean1vec{{1, 2},
- {-1, -2}};
- arm::app::Array2d<float> mean1(2,2); /* {{1, 2},{-1, -2}} */
- populateArray2dWithVectorOfVector(mean1vec, mean1);
- REQUIRE(0 == Approx(TestPreprocess::GetMean(mean1)));
-
- arm::app::Array2d<float> mean2(2, 2);
- std::fill(mean2.begin(), mean2.end(), 0.f);
- REQUIRE(0 == Approx(TestPreprocess::GetMean(mean2)));
-
- arm::app::Array2d<float> mean3(3,3);
- std::fill(mean3.begin(), mean3.end(), 1.f);
- REQUIRE(1 == Approx(TestPreprocess::GetMean(mean3)));
- }
-
- SECTION("Std")
- {
- arm::app::Array2d<float> std1(2, 2);
- std::fill(std1.begin(), std1.end(), 0.f); /* {{0, 0}, {0, 0}} */
- REQUIRE(0 == Approx(TestPreprocess::GetStdDev(std1, 0)));
-
- std::vector<std::vector<float>> std2vec{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 0}};
- arm::app::Array2d<float> std2(2,5);
- populateArray2dWithVectorOfVector(std2vec, std2);
- const float mean = TestPreprocess::GetMean(std2);
- REQUIRE(2.872281323 == Approx(TestPreprocess::GetStdDev(std2, mean)));
-
- arm::app::Array2d<float> std3(2,2);
- std::fill(std3.begin(), std3.end(), 1.f); /* std3{{1, 1}, {1, 1}}; */
- REQUIRE(0 == Approx(TestPreprocess::GetStdDev(std3, 1)));
- }
-
SECTION("Norm") {
auto checker = [&](arm::app::Array2d<float>& d, std::vector<float>& g) {
TestPreprocess::NormaliseVec(d);
diff --git a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc
index 9ed2e1b..d0b6505 100644
--- a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc
+++ b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +16,7 @@
*/
#include "Wav2LetterPostprocess.hpp"
#include "Wav2LetterModel.hpp"
+#include "ClassificationResult.hpp"
#include <algorithm>
#include <catch.hpp>
@@ -47,85 +48,105 @@ TEST_CASE("Checking return value")
{
SECTION("Mismatched post processing parameters and tensor size")
{
- const uint32_t ctxLen = 5;
- const uint32_t innerLen = 3;
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
-
+ const uint32_t outputCtxLen = 5;
+ arm::app::AsrClassifier classifier;
+ arm::app::Wav2LetterModel model;
+ model.Init();
+ std::vector<std::string> dummyLabels = {"a", "b", "$"};
+ const uint32_t blankTokenIdx = 2;
+ std::vector<arm::app::ClassificationResult> dummyResult;
std::vector <int> tensorShape = {1, 1, 1, 13};
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
- REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ tensorShape, 100, tensorVec);
+
+ arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
+
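+ /* Shape {1, 1, 1, 13} leaves a single row at ms_outputRowsIdx, too few
+ for 2 * outputCtxLen rows of context, so post processing is expected
+ to fail. */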
+ REQUIRE(!post.DoPostProcess());
}
SECTION("Post processing succeeds")
{
- const uint32_t ctxLen = 5;
- const uint32_t innerLen = 3;
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0};
-
- std::vector <int> tensorShape = {1, 1, 13, 1};
- std::vector <int8_t> tensorVec;
+ const uint32_t outputCtxLen = 5;
+ arm::app::AsrClassifier classifier;
+ arm::app::Wav2LetterModel model;
+ model.Init();
+ std::vector<std::string> dummyLabels = {"a", "b", "$"};
+ const uint32_t blankTokenIdx = 2;
+ std::vector<arm::app::ClassificationResult> dummyResult;
+ std::vector<int> tensorShape = {1, 1, 13, 1};
+ std::vector<int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ tensorShape, 100, tensorVec);
+
+ arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
- std::vector <int8_t> originalVec = tensorVec;
+ std::vector<int8_t> originalVec = tensorVec;
/* This step should not erase anything. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ REQUIRE(post.DoPostProcess());
}
}
TEST_CASE("Postprocessing - erasing required elements")
{
- constexpr uint32_t ctxLen = 5;
+ constexpr uint32_t outputCtxLen = 5;
constexpr uint32_t innerLen = 3;
- constexpr uint32_t nRows = 2*ctxLen + innerLen;
+ constexpr uint32_t nRows = 2*outputCtxLen + innerLen;
constexpr uint32_t nCols = 10;
constexpr uint32_t blankTokenIdx = nCols - 1;
- std::vector <int> tensorShape = {1, 1, nRows, nCols};
+ std::vector<int> tensorShape = {1, 1, nRows, nCols};
+ arm::app::AsrClassifier classifier;
+ arm::app::Wav2LetterModel model;
+ model.Init();
+ std::vector<std::string> dummyLabels = {"a", "b", "$"};
+ std::vector<arm::app::ClassificationResult> dummyResult;
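+
+ /* Output rows are laid out as [outputCtxLen | innerLen | outputCtxLen];
+ erasing a context block zeroes it and sets its blank-token column to 1. */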
SECTION("First and last iteration")
{
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
- std::vector <int8_t> tensorVec;
- TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ std::vector<int8_t> tensorVec;
+ TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
+ arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
- std::vector <int8_t> originalVec = tensorVec;
+ std::vector<int8_t> originalVec = tensorVec;
/* This step should not erase anything. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+ post.m_lastIteration = true;
+ REQUIRE(post.DoPostProcess());
REQUIRE(originalVec == tensorVec);
}
SECTION("Right context erase")
{
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ tensorShape, 100, tensorVec);
+ arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
- std::vector <int8_t> originalVec = tensorVec;
+ std::vector<int8_t> originalVec = tensorVec;
/* This step should erase the right context only. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ post.m_lastIteration = false;
+ REQUIRE(post.DoPostProcess());
REQUIRE(originalVec != tensorVec);
/* The last outputCtxLen * nCols elements should be gone. */
- for (size_t i = 0; i < ctxLen; ++i) {
+ for (size_t i = 0; i < outputCtxLen; ++i) {
for (size_t j = 0; j < nCols; ++j) {
- /* Check right context elements are zeroed. */
+ /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */
if (j == blankTokenIdx) {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
} else {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
}
/* Check left context is preserved. */
@@ -134,46 +155,47 @@ TEST_CASE("Postprocessing - erasing required elements")
}
/* Check inner elements are preserved. */
- for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+ for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
CHECK(tensorVec[i] == originalVec[i]);
}
}
SECTION("Left and right context erase")
{
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ tensorShape, 100, tensorVec);
+ arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
std::vector <int8_t> originalVec = tensorVec;
/* This step should erase right context. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ post.m_lastIteration = false;
+ REQUIRE(post.DoPostProcess());
/* Calling it the second time should erase the left context. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false));
+ REQUIRE(post.DoPostProcess());
REQUIRE(originalVec != tensorVec);
/* The first and last outputCtxLen * nCols elements should be gone. */
- for (size_t i = 0; i < ctxLen; ++i) {
+ for (size_t i = 0; i < outputCtxLen; ++i) {
for (size_t j = 0; j < nCols; ++j) {
/* Check left and right context elements are zeroed. */
if (j == blankTokenIdx) {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1);
CHECK(tensorVec[i*nCols + j] == 1);
} else {
- CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0);
+ CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0);
CHECK(tensorVec[i*nCols + j] == 0);
}
}
}
/* Check inner elements are preserved. */
- for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) {
+ for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) {
/* Inner elements should match the original. */
CHECK(tensorVec[i] == originalVec[i]);
}
@@ -181,18 +203,20 @@ TEST_CASE("Postprocessing - erasing required elements")
SECTION("Try left context erase")
{
- /* Should not be able to erase the left context if it is the first iteration. */
- arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx};
-
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
- tensorShape, 100, tensorVec);
+ tensorShape, 100, tensorVec);
+
+ /* Should not be able to erase the left context if it is the first iteration. */
+ arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
std::vector <int8_t> originalVec = tensorVec;
/* A single call on the first (and also last) iteration should leave the tensor untouched. */
- REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true));
+ post.m_lastIteration = true;
+ REQUIRE(post.DoPostProcess());
REQUIRE(originalVec == tensorVec);
}
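
Taken together, these sections show the lifecycle of the new post-processing API: classifier, output tensor, labels, result sink and erase geometry are bound at construction, and the old per-invocation lastIteration argument becomes object state. A condensed usage sketch, reusing the tests' variables (whether m_lastIteration remains a public member is a detail of this patch):

    arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult,
                                  outputCtxLen, blankTokenIdx,
                                  arm::app::Wav2LetterModel::ms_outputRowsIdx};

    post.m_lastIteration = false; /* Mid-stream window: contexts get erased. */
    post.DoPostProcess();

    post.m_lastIteration = true;  /* Final window: trailing context is kept. */
    post.DoPostProcess();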
diff --git a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
index 457257f..0280af6 100644
--- a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
+++ b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,55 +24,46 @@ constexpr uint32_t numMfccVectors = 10;
/* Test vector output: generated using test-asr-preprocessing.py. */
int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = {
- /* Feature vec 0. */
- -32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, /* Delta 2. */
-
- /* Feature vec 1. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 2. */
- -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 3. */
- -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 4 : this should have valid delta 1 and delta 2. */
- -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
- -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5,
- -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15,
-
- /* Feature vec 5 : this should have valid delta 1 and delta 2. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
- -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13,
- -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6,
-
- /* Feature vec 6. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 7. */
- -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 8. */
- -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-
- /* Feature vec 9. */
- -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
- -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
- -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10
+ /* Feature vec 0. */
+ {-32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, /* Delta 2. */
+ /* Feature vec 1. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 2. */
+ {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 3. */
+ {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 4 : this should have valid delta 1 and delta 2. */
+ {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+ -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5,
+ -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15},
+ /* Feature vec 5 : this should have valid delta 1 and delta 2. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12,
+ -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13,
+ -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6},
+ /* Feature vec 6. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 7. */
+ {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 8. */
+ {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10},
+ /* Feature vec 9. */
+ {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11,
+ -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11,
+ -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}
};
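
For reference, each row above holds numMfccFeatures * 3 == 39 values: 13 MFCCs followed by their first and second deltas, which is why only the interior frames (vectors 4 and 5) carry non-trivial deltas. A textbook delta-regression sketch is below; the window length and edge handling of the kit's ComputeDeltas may differ, so treat this as illustrative only.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    /* d[t][k] = sum_n n * (c[t+n][k] - c[t-n][k]) / (2 * sum_n n^2),
     * with frame indices clamped at the edges. */
    static std::vector<float> DeltaAt(const std::vector<std::vector<float>>& frames,
                                      std::size_t t, std::size_t N)
    {
        const std::size_t numCoeffs = frames[0].size();
        std::vector<float> delta(numCoeffs, 0.0f);
        float denom = 0.0f;
        for (std::size_t n = 1; n <= N; ++n) {
            denom += 2.0f * n * n;
        }
        for (std::size_t k = 0; k < numCoeffs; ++k) {
            float num = 0.0f;
            for (std::size_t n = 1; n <= N; ++n) {
                const std::size_t tPlus  = std::min(t + n, frames.size() - 1);
                const std::size_t tMinus = (t >= n) ? (t - n) : 0;
                num += static_cast<float>(n) * (frames[tPlus][k] - frames[tMinus][k]);
            }
            delta[k] = num / denom;
        }
        return delta;
    }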
void PopulateTestWavVector(std::vector<int16_t>& vec)
@@ -97,15 +88,16 @@ void PopulateTestWavVector(std::vector<int16_t>& vec)
TEST_CASE("Preprocessing calculation INT8")
{
/* Constants. */
- const uint32_t windowLen = 512;
- const uint32_t windowStride = 160;
- int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors};
- const float quantScale = 0.1410219967365265;
- const int quantOffset = -11;
+ const uint32_t mfccWindowLen = 512;
+ const uint32_t mfccWindowStride = 160;
+ int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors};
+ const float quantScale = 0.1410219967365265;
+ const int quantOffset = -11;
/* Test wav memory. */
- std::vector <int16_t> testWav((windowStride * numMfccVectors) +
- (windowLen - windowStride));
+ std::vector<int16_t> testWav((mfccWindowStride * numMfccVectors) +
+ (mfccWindowLen - mfccWindowStride));
/* Populate with dummy input. */
PopulateTestWavVector(testWav);
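
Before the next hunk: the revised pre-processing API binds the input tensor and MFCC geometry at construction, and DoPreProcess() then consumes the raw audio directly. A condensed sketch of the call pattern the hunk below exercises, using the test's constants (tensor setup elided):

    arm::app::ASRPreProcess prep{&inputTensor, numMfccFeatures, numMfccVectors,
                                 mfccWindowLen, mfccWindowStride};

    /* On success, inputTensor holds numMfccVectors frames of quantised
     * MFCCs plus their first and second deltas. */
    REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size()));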
@@ -115,20 +107,20 @@ TEST_CASE("Preprocessing calculation INT8")
/* Initialise dimensions and the test tensor. */
TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(dimArray);
- TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor(
- tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
+ TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor(
+ tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
/* Initialise pre-processing module. */
- arm::app::audio::asr::Preprocess prep{
- numMfccFeatures, windowLen, windowStride, numMfccVectors};
+ arm::app::ASRPreProcess prep{&inputTensor,
+ numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride};
/* Invoke pre-processing. */
- REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor));
+ REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size()));
/* Wrap the tensor with a std::vector for ease. */
- auto* tensorData = tflite::GetTensorData<int8_t>(&tensor);
+ auto* tensorData = tflite::GetTensorData<int8_t>(&inputTensor);
std::vector <int8_t> vecResults =
- std::vector<int8_t>(tensorData, tensorData + tensor.bytes);
+ std::vector<int8_t>(tensorData, tensorData + inputTensor.bytes);
/* Check sizes. */
REQUIRE(vecResults.size() == sizeof(expectedResult));