path: root/source/use_case/asr
author     Richard Burton <richard.burton@arm.com>  2022-04-22 16:14:57 +0100
committer  Richard Burton <richard.burton@arm.com>  2022-04-22 16:14:57 +0100
commit     b40ecf8522052809d2351677a96195d69e4d0c16 (patch)
tree       8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/asr
parent     c291144b7f08c21d08cdaf79cc64dc420ca70070 (diff)
download   ml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz
MLECO-3174: Minor refactoring to implemented use case APIs
Looks large but it is mainly just many small adjustments:
Removed the inference runner code as it wasn't used
Fixes to doc strings
Consistent naming, e.g. Asr/Kws instead of ASR/KWS

Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
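A minimal sketch of the resulting call sequence in source/use_case/asr/src/UseCaseHandler.cc, assuming the tensors, audio window, profiler and application context set up earlier in the handler are in scope (the UseCaseRunner wrapper is removed and each stage is now invoked directly):

    std::vector<ClassificationResult> singleInfResult;
    const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen);

    /* Renamed classes: AsrPreProcess/AsrPostProcess (previously ASRPreProcess/ASRPostProcess).
     * The output tensor now precedes the classifier in the post-process constructor. */
    AsrPreProcess preProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures,
                             inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
                             mfccFrameLen, mfccFrameStride);
    AsrPostProcess postProcess(outputTensor, ctx.Get<AsrClassifier&>("classifier"),
                               ctx.Get<std::vector<std::string>&>("labels"),
                               singleInfResult, outputCtxLen,
                               Wav2LetterModel::ms_blankTokenIdx,
                               Wav2LetterModel::ms_outputRowsIdx);

    /* Per audio window: pre-process, run inference, then post-process. */
    if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) {
        return false;
    }
    if (!RunInference(model, profiler)) {
        return false;
    }
    postProcess.m_lastIteration = !audioDataSlider.HasNext(); /* Last window of the clip? */
    if (!postProcess.DoPostProcess()) {
        return false;
    }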
Diffstat (limited to 'source/use_case/asr')
-rw-r--r--  source/use_case/asr/include/AsrClassifier.hpp           10
-rw-r--r--  source/use_case/asr/include/Wav2LetterModel.hpp          1
-rw-r--r--  source/use_case/asr/include/Wav2LetterPostprocess.hpp   15
-rw-r--r--  source/use_case/asr/include/Wav2LetterPreprocess.hpp    28
-rw-r--r--  source/use_case/asr/src/AsrClassifier.cc                196
-rw-r--r--  source/use_case/asr/src/UseCaseHandler.cc                38
-rw-r--r--  source/use_case/asr/src/Wav2LetterPostprocess.cc         24
-rw-r--r--  source/use_case/asr/src/Wav2LetterPreprocess.cc          28
8 files changed, 179 insertions, 161 deletions
diff --git a/source/use_case/asr/include/AsrClassifier.hpp b/source/use_case/asr/include/AsrClassifier.hpp
index 67a200e..a07a721 100644
--- a/source/use_case/asr/include/AsrClassifier.hpp
+++ b/source/use_case/asr/include/AsrClassifier.hpp
@@ -35,10 +35,10 @@ namespace app {
* @param[in] use_softmax Whether softmax scaling should be applied to model output.
* @return true if successful, false otherwise.
**/
- bool GetClassificationResults(
- TfLiteTensor* outputTensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax = false) override;
+ bool GetClassificationResults(TfLiteTensor* outputTensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector<std::string>& labels,
+ uint32_t topNCount, bool use_softmax = false) override;
private:
/**
@@ -54,7 +54,7 @@ namespace app {
template<typename T>
bool GetTopResults(TfLiteTensor* tensor,
std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
+ const std::vector<std::string>& labels, double scale, double zeroPoint);
};
} /* namespace app */
diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp
index 895df2b..0078e44 100644
--- a/source/use_case/asr/include/Wav2LetterModel.hpp
+++ b/source/use_case/asr/include/Wav2LetterModel.hpp
@@ -36,6 +36,7 @@ namespace app {
static constexpr uint32_t ms_outputRowsIdx = 2;
static constexpr uint32_t ms_outputColsIdx = 3;
+ /* Model specific constants. */
static constexpr uint32_t ms_blankTokenIdx = 28;
static constexpr uint32_t ms_numMfccFeatures = 13;
diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
index 45defa5..446014d 100644
--- a/source/use_case/asr/include/Wav2LetterPostprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp
@@ -30,23 +30,24 @@ namespace app {
* @brief Helper class to manage tensor post-processing for "wav2letter"
* output.
*/
- class ASRPostProcess : public BasePostProcess {
+ class AsrPostProcess : public BasePostProcess {
public:
bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */
/**
* @brief Constructor
- * @param[in] outputTensor Pointer to the output Tensor.
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Object used to get top N results from classification.
* @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in/out] result Vector of classification results to store decoded outputs.
+ * @param[in/out] result Vector of classification results to store decoded outputs.
* @param[in] outputContextLen Left/right context length for output tensor.
* @param[in] blankTokenIdx Index in the labels that the "Blank token" takes.
* @param[in] reductionAxis The axis that the logits of each time step is on.
**/
- ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
- const std::vector<std::string>& labels, asr::ResultVec& result,
- uint32_t outputContextLen,
- uint32_t blankTokenIdx, uint32_t reductionAxis);
+ AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
+ const std::vector<std::string>& labels, asr::ResultVec& result,
+ uint32_t outputContextLen,
+ uint32_t blankTokenIdx, uint32_t reductionAxis);
/**
* @brief Should perform post-processing of the result of inference then
diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
index 8c12b3d..dc9a415 100644
--- a/source/use_case/asr/include/Wav2LetterPreprocess.hpp
+++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp
@@ -31,22 +31,22 @@ namespace app {
* for ASR. */
using AudioWindow = audio::SlidingWindow<const int16_t>;
- class ASRPreProcess : public BasePreProcess {
+ class AsrPreProcess : public BasePreProcess {
public:
/**
* @brief Constructor.
* @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
* @param[in] numMfccFeatures Number of MFCC features per window.
+ * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
+ * for an inference.
* @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window.
* @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window.
- * @param[in] mfccWindowStride Number of MFCC vectors that need to be calculated
- * for an inference.
*/
- ASRPreProcess(TfLiteTensor* inputTensor,
- uint32_t numMfccFeatures,
- uint32_t audioWindowLen,
- uint32_t mfccWindowLen,
- uint32_t mfccWindowStride);
+ AsrPreProcess(TfLiteTensor* inputTensor,
+ uint32_t numMfccFeatures,
+ uint32_t numFeatureFrames,
+ uint32_t mfccWindowLen,
+ uint32_t mfccWindowStride);
/**
* @brief Calculates the features required from audio data. This
@@ -130,9 +130,9 @@ namespace app {
}
/* Populate. */
- T * outputBufMfcc = outputBuf;
- T * outputBufD1 = outputBuf + this->m_numMfccFeats;
- T * outputBufD2 = outputBufD1 + this->m_numMfccFeats;
+ T* outputBufMfcc = outputBuf;
+ T* outputBufD1 = outputBuf + this->m_numMfccFeats;
+ T* outputBufD2 = outputBufD1 + this->m_numMfccFeats;
const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */
const float minVal = std::numeric_limits<T>::min();
@@ -141,13 +141,13 @@ namespace app {
/* Need to transpose while copying and concatenating the tensor. */
for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) {
for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) {
- *outputBufMfcc++ = static_cast<T>(ASRPreProcess::GetQuantElem(
+ *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_mfccBuf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD1++ = static_cast<T>(ASRPreProcess::GetQuantElem(
+ *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_delta1Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
- *outputBufD2++ = static_cast<T>(ASRPreProcess::GetQuantElem(
+ *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem(
this->m_delta2Buf(i, j), quantScale,
quantOffset, minVal, maxVal));
}
diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc
index 84e66b7..4ba8c7b 100644
--- a/source/use_case/asr/src/AsrClassifier.cc
+++ b/source/use_case/asr/src/AsrClassifier.cc
@@ -20,117 +20,125 @@
#include "TensorFlowLiteMicro.hpp"
#include "Wav2LetterModel.hpp"
-template<typename T>
-bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint)
-{
- const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
- const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
-
- if (nLetters != labels.size()) {
- printf("Output size doesn't match the labels' size\n");
- return false;
- }
+namespace arm {
+namespace app {
+
+ template<typename T>
+ bool AsrClassifier::GetTopResults(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint)
+ {
+ const uint32_t nElems = tensor->dims->data[Wav2LetterModel::ms_outputRowsIdx];
+ const uint32_t nLetters = tensor->dims->data[Wav2LetterModel::ms_outputColsIdx];
+
+ if (nLetters != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
- /* NOTE: tensor's size verification against labels should be
- * checked by the calling/public function. */
- if (nLetters < 1) {
- return false;
- }
+ /* NOTE: tensor's size verification against labels should be
+ * checked by the calling/public function. */
+ if (nLetters < 1) {
+ return false;
+ }
- /* Final results' container. */
- vecResults = std::vector<ClassificationResult>(nElems);
+ /* Final results' container. */
+ vecResults = std::vector<ClassificationResult>(nElems);
- T* tensorData = tflite::GetTensorData<T>(tensor);
+ T* tensorData = tflite::GetTensorData<T>(tensor);
- /* Get the top 1 results. */
- for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
- std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
+ /* Get the top 1 results. */
+ for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
+ std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
- for (uint32_t j = 1; j < nLetters; ++j) {
- if (top_1.first < tensorData[row + j]) {
- top_1.first = tensorData[row + j];
- top_1.second = j;
+ for (uint32_t j = 1; j < nLetters; ++j) {
+ if (top_1.first < tensorData[row + j]) {
+ top_1.first = tensorData[row + j];
+ top_1.second = j;
+ }
}
+
+ double score = static_cast<int> (top_1.first);
+ vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
+ vecResults[i].m_label = labels[top_1.second];
+ vecResults[i].m_labelIdx = top_1.second;
}
- double score = static_cast<int> (top_1.first);
- vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
- vecResults[i].m_label = labels[top_1.second];
- vecResults[i].m_labelIdx = top_1.second;
+ return true;
}
-
- return true;
-}
-template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-
-bool arm::app::AsrClassifier::GetClassificationResults(
+ template bool AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels,
+ double scale, double zeroPoint);
+ template bool AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels,
+ double scale, double zeroPoint);
+
+ bool AsrClassifier::GetClassificationResults(
TfLiteTensor* outputTensor,
std::vector<ClassificationResult>& vecResults,
const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax)
-{
- UNUSED(use_softmax);
- vecResults.clear();
+ {
+ UNUSED(use_softmax);
+ vecResults.clear();
- constexpr int minTensorDims = static_cast<int>(
- (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)?
- arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx);
+ constexpr int minTensorDims = static_cast<int>(
+ (Wav2LetterModel::ms_outputRowsIdx > Wav2LetterModel::ms_outputColsIdx)?
+ Wav2LetterModel::ms_outputRowsIdx : Wav2LetterModel::ms_outputColsIdx);
- constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx;
+ constexpr uint32_t outColsIdx = Wav2LetterModel::ms_outputColsIdx;
- /* Sanity checks. */
- if (outputTensor == nullptr) {
- printf_err("Output vector is null pointer.\n");
- return false;
- } else if (outputTensor->dims->size < minTensorDims) {
- printf_err("Output tensor expected to be %dD\n", minTensorDims);
- return false;
- } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
- printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
- return false;
- } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
- printf("Output size doesn't match the labels' size\n");
- return false;
- }
+ /* Sanity checks. */
+ if (outputTensor == nullptr) {
+ printf_err("Output vector is null pointer.\n");
+ return false;
+ } else if (outputTensor->dims->size < minTensorDims) {
+ printf_err("Output tensor expected to be %dD\n", minTensorDims);
+ return false;
+ } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
+ printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
+ return false;
+ } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
- if (topNCount != 1) {
- warn("TopNCount value ignored in this implementation\n");
- }
+ if (topNCount != 1) {
+ warn("TopNCount value ignored in this implementation\n");
+ }
- /* To return the floating point values, we need quantization parameters. */
- QuantParams quantParams = GetTensorQuantParams(outputTensor);
-
- bool resultState;
-
- switch (outputTensor->type) {
- case kTfLiteUInt8:
- resultState = this->GetTopResults<uint8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
- break;
- case kTfLiteInt8:
- resultState = this->GetTopResults<int8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
- break;
- default:
- printf_err("Tensor type %s not supported by classifier\n",
- TfLiteTypeGetName(outputTensor->type));
+ /* To return the floating point values, we need quantization parameters. */
+ QuantParams quantParams = GetTensorQuantParams(outputTensor);
+
+ bool resultState;
+
+ switch (outputTensor->type) {
+ case kTfLiteUInt8:
+ resultState = this->GetTopResults<uint8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
+ break;
+ case kTfLiteInt8:
+ resultState = this->GetTopResults<int8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
+ break;
+ default:
+ printf_err("Tensor type %s not supported by classifier\n",
+ TfLiteTypeGetName(outputTensor->type));
+ return false;
+ }
+
+ if (!resultState) {
+ printf_err("Failed to get sorted set\n");
return false;
- }
+ }
- if (!resultState) {
- printf_err("Failed to get sorted set\n");
- return false;
- }
+ return true;
+ }
- return true;
-} \ No newline at end of file
+} /* namespace app */
+} /* namespace arm */
\ No newline at end of file
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 7fe959b..850bdc2 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -33,9 +33,9 @@ namespace arm {
namespace app {
/**
- * @brief Presents ASR inference results.
- * @param[in] results Vector of ASR classification results to be displayed.
- * @return true if successful, false otherwise.
+ * @brief Presents ASR inference results.
+ * @param[in] results Vector of ASR classification results to be displayed.
+ * @return true if successful, false otherwise.
**/
static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
@@ -63,6 +63,9 @@ namespace app {
return false;
}
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+
/* Get input shape. Dimensions of the tensor should have been verified by
* the callee. */
TfLiteIntArray* inputShape = model.GetInputShape(0);
@@ -78,19 +81,19 @@ namespace app {
const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
/* Set up pre and post-processing objects. */
- ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures,
- inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride);
+ AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures,
+ inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+ mfccFrameLen, mfccFrameStride);
std::vector<ClassificationResult> singleInfResult;
- const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
- ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"),
- model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"),
+ const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen);
+ AsrPostProcess postProcess = AsrPostProcess(
+ outputTensor, ctx.Get<AsrClassifier&>("classifier"),
+ ctx.Get<std::vector<std::string>&>("labels"),
singleInfResult, outputCtxLen,
Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
);
- UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model);
-
/* Loop to process audio clips. */
do {
hal_lcd_clear(COLOR_BLACK);
@@ -147,16 +150,20 @@ namespace app {
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
/* Run the pre-processing, inference and post-processing. */
- runner.PreProcess(inferenceWindow, inferenceWindowLen);
+ if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) {
+ printf_err("Pre-processing failed.");
+ return false;
+ }
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
+ /* Post processing needs to know if we are on the last audio window. */
postProcess.m_lastIteration = !audioDataSlider.HasNext();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -166,7 +173,6 @@ namespace app {
audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
armDumpTensor(outputTensor,
outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */
diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc
index e3e1999..42f434e 100644
--- a/source/use_case/asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc
@@ -24,7 +24,7 @@
namespace arm {
namespace app {
- ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+ AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
const uint32_t outputContextLen,
const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
@@ -38,11 +38,11 @@ namespace app {
m_blankTokenIdx(blankTokenIdx),
m_reductionAxisIdx(reductionAxisIdx)
{
- this->m_outputInnerLen = ASRPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
+ this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
}
- bool ASRPostProcess::DoPostProcess()
+ bool AsrPostProcess::DoPostProcess()
{
/* Basic checks. */
if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
@@ -51,7 +51,7 @@ namespace app {
/* Irrespective of tensor type, we use unsigned "byte" */
auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
- const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor);
+ const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor);
/* Other sanity checks. */
if (0 == elemSz) {
@@ -79,7 +79,7 @@ namespace app {
return true;
}
- bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
+ bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
{
if (nullptr == tensor) {
return false;
@@ -101,7 +101,7 @@ namespace app {
return true;
}
- uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
+ uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
{
switch(tensor->type) {
case kTfLiteUInt8:
@@ -120,7 +120,7 @@ namespace app {
return 0;
}
- bool ASRPostProcess::EraseSectionsRowWise(
+ bool AsrPostProcess::EraseSectionsRowWise(
uint8_t* ptrData,
const uint32_t strideSzBytes,
const bool lastIteration)
@@ -157,7 +157,7 @@ namespace app {
return true;
}
- uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model)
+ uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model)
{
TfLiteTensor* inputTensor = model.GetInputTensor(0);
const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
@@ -168,21 +168,23 @@ namespace app {
return inputRows;
}
- uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
+ uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
{
const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
if (outputRows == 0) {
printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
Wav2LetterModel::ms_outputRowsIdx);
}
+
+ /* Watching for underflow. */
int innerLen = (outputRows - (2 * outputCtxLen));
return std::max(innerLen, 0);
}
- uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+ uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
{
- const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model);
+ const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model);
const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc
index 590d08a..92b0631 100644
--- a/source/use_case/asr/src/Wav2LetterPreprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc
@@ -25,9 +25,9 @@
namespace arm {
namespace app {
- ASRPreProcess::ASRPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
- const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
- const uint32_t mfccWindowStride
+ AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
+ const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
+ const uint32_t mfccWindowStride
):
m_mfcc(numMfccFeatures, mfccWindowLen),
m_inputTensor(inputTensor),
@@ -44,7 +44,7 @@ namespace app {
}
}
- bool ASRPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
+ bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
{
this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(
static_cast<const int16_t*>(audioData), audioDataLen,
@@ -82,7 +82,7 @@ namespace app {
}
/* Compute first and second order deltas from MFCCs. */
- ASRPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
+ AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
/* Standardize calculated features. */
this->Standarize();
@@ -112,9 +112,9 @@ namespace app {
return false;
}
- bool ASRPreProcess::ComputeDeltas(Array2d<float>& mfcc,
- Array2d<float>& delta1,
- Array2d<float>& delta2)
+ bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc,
+ Array2d<float>& delta1,
+ Array2d<float>& delta2)
{
const std::vector <float> delta1Coeffs =
{6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
@@ -167,7 +167,7 @@ namespace app {
return true;
}
- void ASRPreProcess::StandardizeVecF32(Array2d<float>& vec)
+ void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec)
{
auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
@@ -186,14 +186,14 @@ namespace app {
}
}
- void ASRPreProcess::Standarize()
+ void AsrPreProcess::Standarize()
{
- ASRPreProcess::StandardizeVecF32(this->m_mfccBuf);
- ASRPreProcess::StandardizeVecF32(this->m_delta1Buf);
- ASRPreProcess::StandardizeVecF32(this->m_delta2Buf);
+ AsrPreProcess::StandardizeVecF32(this->m_mfccBuf);
+ AsrPreProcess::StandardizeVecF32(this->m_delta1Buf);
+ AsrPreProcess::StandardizeVecF32(this->m_delta2Buf);
}
- float ASRPreProcess::GetQuantElem(
+ float AsrPreProcess::GetQuantElem(
const float elem,
const float quantScale,
const int quantOffset,