summaryrefslogtreecommitdiff
path: root/source/use_case/asr/src
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/asr/src')
-rw-r--r--source/use_case/asr/src/AsrClassifier.cc196
-rw-r--r--source/use_case/asr/src/UseCaseHandler.cc38
-rw-r--r--source/use_case/asr/src/Wav2LetterPostprocess.cc24
-rw-r--r--source/use_case/asr/src/Wav2LetterPreprocess.cc28
4 files changed, 151 insertions, 135 deletions
diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc
index 84e66b7..4ba8c7b 100644
--- a/source/use_case/asr/src/AsrClassifier.cc
+++ b/source/use_case/asr/src/AsrClassifier.cc
@@ -20,117 +20,125 @@
#include "TensorFlowLiteMicro.hpp"
#include "Wav2LetterModel.hpp"
-template<typename T>
-bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint)
-{
- const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
- const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];
-
- if (nLetters != labels.size()) {
- printf("Output size doesn't match the labels' size\n");
- return false;
- }
+namespace arm {
+namespace app {
+
+ template<typename T>
+ bool AsrClassifier::GetTopResults(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels, double scale, double zeroPoint)
+ {
+ const uint32_t nElems = tensor->dims->data[Wav2LetterModel::ms_outputRowsIdx];
+ const uint32_t nLetters = tensor->dims->data[Wav2LetterModel::ms_outputColsIdx];
+
+ if (nLetters != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
- /* NOTE: tensor's size verification against labels should be
- * checked by the calling/public function. */
- if (nLetters < 1) {
- return false;
- }
+ /* NOTE: tensor's size verification against labels should be
+ * checked by the calling/public function. */
+ if (nLetters < 1) {
+ return false;
+ }
- /* Final results' container. */
- vecResults = std::vector<ClassificationResult>(nElems);
+ /* Final results' container. */
+ vecResults = std::vector<ClassificationResult>(nElems);
- T* tensorData = tflite::GetTensorData<T>(tensor);
+ T* tensorData = tflite::GetTensorData<T>(tensor);
- /* Get the top 1 results. */
- for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
- std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
+ /* Get the top 1 results. */
+ for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) {
+ std::pair<T, uint32_t> top_1 = std::make_pair(tensorData[row + 0], 0);
- for (uint32_t j = 1; j < nLetters; ++j) {
- if (top_1.first < tensorData[row + j]) {
- top_1.first = tensorData[row + j];
- top_1.second = j;
+ for (uint32_t j = 1; j < nLetters; ++j) {
+ if (top_1.first < tensorData[row + j]) {
+ top_1.first = tensorData[row + j];
+ top_1.second = j;
+ }
}
+
+ double score = static_cast<int> (top_1.first);
+ vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
+ vecResults[i].m_label = labels[top_1.second];
+ vecResults[i].m_labelIdx = top_1.second;
}
- double score = static_cast<int> (top_1.first);
- vecResults[i].m_normalisedVal = scale * (score - zeroPoint);
- vecResults[i].m_label = labels[top_1.second];
- vecResults[i].m_labelIdx = top_1.second;
+ return true;
}
-
- return true;
-}
-template bool arm::app::AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
- std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, double scale, double zeroPoint);
-
-bool arm::app::AsrClassifier::GetClassificationResults(
+ template bool AsrClassifier::GetTopResults<uint8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels,
+ double scale, double zeroPoint);
+ template bool AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tensor,
+ std::vector<ClassificationResult>& vecResults,
+ const std::vector <std::string>& labels,
+ double scale, double zeroPoint);
+
+ bool AsrClassifier::GetClassificationResults(
TfLiteTensor* outputTensor,
std::vector<ClassificationResult>& vecResults,
const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax)
-{
- UNUSED(use_softmax);
- vecResults.clear();
+ {
+ UNUSED(use_softmax);
+ vecResults.clear();
- constexpr int minTensorDims = static_cast<int>(
- (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)?
- arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx);
+ constexpr int minTensorDims = static_cast<int>(
+ (Wav2LetterModel::ms_outputRowsIdx > Wav2LetterModel::ms_outputColsIdx)?
+ Wav2LetterModel::ms_outputRowsIdx : Wav2LetterModel::ms_outputColsIdx);
- constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx;
+ constexpr uint32_t outColsIdx = Wav2LetterModel::ms_outputColsIdx;
- /* Sanity checks. */
- if (outputTensor == nullptr) {
- printf_err("Output vector is null pointer.\n");
- return false;
- } else if (outputTensor->dims->size < minTensorDims) {
- printf_err("Output tensor expected to be %dD\n", minTensorDims);
- return false;
- } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
- printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
- return false;
- } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
- printf("Output size doesn't match the labels' size\n");
- return false;
- }
+ /* Sanity checks. */
+ if (outputTensor == nullptr) {
+ printf_err("Output vector is null pointer.\n");
+ return false;
+ } else if (outputTensor->dims->size < minTensorDims) {
+ printf_err("Output tensor expected to be %dD\n", minTensorDims);
+ return false;
+ } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) < topNCount) {
+ printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount);
+ return false;
+ } else if (static_cast<uint32_t>(outputTensor->dims->data[outColsIdx]) != labels.size()) {
+ printf("Output size doesn't match the labels' size\n");
+ return false;
+ }
- if (topNCount != 1) {
- warn("TopNCount value ignored in this implementation\n");
- }
+ if (topNCount != 1) {
+ warn("TopNCount value ignored in this implementation\n");
+ }
- /* To return the floating point values, we need quantization parameters. */
- QuantParams quantParams = GetTensorQuantParams(outputTensor);
-
- bool resultState;
-
- switch (outputTensor->type) {
- case kTfLiteUInt8:
- resultState = this->GetTopResults<uint8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
- break;
- case kTfLiteInt8:
- resultState = this->GetTopResults<int8_t>(
- outputTensor, vecResults,
- labels, quantParams.scale,
- quantParams.offset);
- break;
- default:
- printf_err("Tensor type %s not supported by classifier\n",
- TfLiteTypeGetName(outputTensor->type));
+ /* To return the floating point values, we need quantization parameters. */
+ QuantParams quantParams = GetTensorQuantParams(outputTensor);
+
+ bool resultState;
+
+ switch (outputTensor->type) {
+ case kTfLiteUInt8:
+ resultState = this->GetTopResults<uint8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
+ break;
+ case kTfLiteInt8:
+ resultState = this->GetTopResults<int8_t>(
+ outputTensor, vecResults,
+ labels, quantParams.scale,
+ quantParams.offset);
+ break;
+ default:
+ printf_err("Tensor type %s not supported by classifier\n",
+ TfLiteTypeGetName(outputTensor->type));
+ return false;
+ }
+
+ if (!resultState) {
+ printf_err("Failed to get sorted set\n");
return false;
- }
+ }
- if (!resultState) {
- printf_err("Failed to get sorted set\n");
- return false;
- }
+ return true;
+ }
- return true;
-} \ No newline at end of file
+} /* namespace app */
+} /* namespace arm */ \ No newline at end of file
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 7fe959b..850bdc2 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -33,9 +33,9 @@ namespace arm {
namespace app {
/**
- * @brief Presents ASR inference results.
- * @param[in] results Vector of ASR classification results to be displayed.
- * @return true if successful, false otherwise.
+ * @brief Presents ASR inference results.
+ * @param[in] results Vector of ASR classification results to be displayed.
+ * @return true if successful, false otherwise.
**/
static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
@@ -63,6 +63,9 @@ namespace app {
return false;
}
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+
/* Get input shape. Dimensions of the tensor should have been verified by
* the callee. */
TfLiteIntArray* inputShape = model.GetInputShape(0);
@@ -78,19 +81,19 @@ namespace app {
const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
/* Set up pre and post-processing objects. */
- ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures,
- inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride);
+ AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures,
+ inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+ mfccFrameLen, mfccFrameStride);
std::vector<ClassificationResult> singleInfResult;
- const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
- ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"),
- model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"),
+ const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen);
+ AsrPostProcess postProcess = AsrPostProcess(
+ outputTensor, ctx.Get<AsrClassifier&>("classifier"),
+ ctx.Get<std::vector<std::string>&>("labels"),
singleInfResult, outputCtxLen,
Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
);
- UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model);
-
/* Loop to process audio clips. */
do {
hal_lcd_clear(COLOR_BLACK);
@@ -147,16 +150,20 @@ namespace app {
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
/* Run the pre-processing, inference and post-processing. */
- runner.PreProcess(inferenceWindow, inferenceWindowLen);
+ if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) {
+ printf_err("Pre-processing failed.");
+ return false;
+ }
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
+ /* Post processing needs to know if we are on the last audio window. */
postProcess.m_lastIteration = !audioDataSlider.HasNext();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -166,7 +173,6 @@ namespace app {
audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
armDumpTensor(outputTensor,
outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */
diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc
index e3e1999..42f434e 100644
--- a/source/use_case/asr/src/Wav2LetterPostprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc
@@ -24,7 +24,7 @@
namespace arm {
namespace app {
- ASRPostProcess::ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor,
+ AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier,
const std::vector<std::string>& labels, std::vector<ClassificationResult>& results,
const uint32_t outputContextLen,
const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx
@@ -38,11 +38,11 @@ namespace app {
m_blankTokenIdx(blankTokenIdx),
m_reductionAxisIdx(reductionAxisIdx)
{
- this->m_outputInnerLen = ASRPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
+ this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen);
this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen);
}
- bool ASRPostProcess::DoPostProcess()
+ bool AsrPostProcess::DoPostProcess()
{
/* Basic checks. */
if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) {
@@ -51,7 +51,7 @@ namespace app {
/* Irrespective of tensor type, we use unsigned "byte" */
auto* ptrData = tflite::GetTensorData<uint8_t>(this->m_outputTensor);
- const uint32_t elemSz = ASRPostProcess::GetTensorElementSize(this->m_outputTensor);
+ const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor);
/* Other sanity checks. */
if (0 == elemSz) {
@@ -79,7 +79,7 @@ namespace app {
return true;
}
- bool ASRPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
+ bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const
{
if (nullptr == tensor) {
return false;
@@ -101,7 +101,7 @@ namespace app {
return true;
}
- uint32_t ASRPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
+ uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor)
{
switch(tensor->type) {
case kTfLiteUInt8:
@@ -120,7 +120,7 @@ namespace app {
return 0;
}
- bool ASRPostProcess::EraseSectionsRowWise(
+ bool AsrPostProcess::EraseSectionsRowWise(
uint8_t* ptrData,
const uint32_t strideSzBytes,
const bool lastIteration)
@@ -157,7 +157,7 @@ namespace app {
return true;
}
- uint32_t ASRPostProcess::GetNumFeatureVectors(const Model& model)
+ uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model)
{
TfLiteTensor* inputTensor = model.GetInputTensor(0);
const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0);
@@ -168,21 +168,23 @@ namespace app {
return inputRows;
}
- uint32_t ASRPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
+ uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen)
{
const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0);
if (outputRows == 0) {
printf_err("Error getting number of output rows for axis: %" PRIu32 "\n",
Wav2LetterModel::ms_outputRowsIdx);
}
+
+ /* Watching for underflow. */
int innerLen = (outputRows - (2 * outputCtxLen));
return std::max(innerLen, 0);
}
- uint32_t ASRPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
+ uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen)
{
- const uint32_t inputRows = ASRPostProcess::GetNumFeatureVectors(model);
+ const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model);
const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx;
diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc
index 590d08a..92b0631 100644
--- a/source/use_case/asr/src/Wav2LetterPreprocess.cc
+++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc
@@ -25,9 +25,9 @@
namespace arm {
namespace app {
- ASRPreProcess::ASRPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
- const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
- const uint32_t mfccWindowStride
+ AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures,
+ const uint32_t numFeatureFrames, const uint32_t mfccWindowLen,
+ const uint32_t mfccWindowStride
):
m_mfcc(numMfccFeatures, mfccWindowLen),
m_inputTensor(inputTensor),
@@ -44,7 +44,7 @@ namespace app {
}
}
- bool ASRPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
+ bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen)
{
this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(
static_cast<const int16_t*>(audioData), audioDataLen,
@@ -82,7 +82,7 @@ namespace app {
}
/* Compute first and second order deltas from MFCCs. */
- ASRPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
+ AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf);
/* Standardize calculated features. */
this->Standarize();
@@ -112,9 +112,9 @@ namespace app {
return false;
}
- bool ASRPreProcess::ComputeDeltas(Array2d<float>& mfcc,
- Array2d<float>& delta1,
- Array2d<float>& delta2)
+ bool AsrPreProcess::ComputeDeltas(Array2d<float>& mfcc,
+ Array2d<float>& delta1,
+ Array2d<float>& delta2)
{
const std::vector <float> delta1Coeffs =
{6.66666667e-02, 5.00000000e-02, 3.33333333e-02,
@@ -167,7 +167,7 @@ namespace app {
return true;
}
- void ASRPreProcess::StandardizeVecF32(Array2d<float>& vec)
+ void AsrPreProcess::StandardizeVecF32(Array2d<float>& vec)
{
auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize());
auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean);
@@ -186,14 +186,14 @@ namespace app {
}
}
- void ASRPreProcess::Standarize()
+ void AsrPreProcess::Standarize()
{
- ASRPreProcess::StandardizeVecF32(this->m_mfccBuf);
- ASRPreProcess::StandardizeVecF32(this->m_delta1Buf);
- ASRPreProcess::StandardizeVecF32(this->m_delta2Buf);
+ AsrPreProcess::StandardizeVecF32(this->m_mfccBuf);
+ AsrPreProcess::StandardizeVecF32(this->m_delta1Buf);
+ AsrPreProcess::StandardizeVecF32(this->m_delta2Buf);
}
- float ASRPreProcess::GetQuantElem(
+ float AsrPreProcess::GetQuantElem(
const float elem,
const float quantScale,
const int quantOffset,