summary | refs | log | tree | commit | diff
path: root/source/use_case/asr/src/UseCaseHandler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/asr/src/UseCaseHandler.cc')
-rw-r--r-- source/use_case/asr/src/UseCaseHandler.cc 38
1 files changed, 22 insertions, 16 deletions
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 7fe959b..850bdc2 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -33,9 +33,9 @@ namespace arm {
namespace app {
/**
- * @brief Presents ASR inference results.
- * @param[in] results Vector of ASR classification results to be displayed.
- * @return true if successful, false otherwise.
+ * @brief Presents ASR inference results.
+ * @param[in] results Vector of ASR classification results to be displayed.
+ * @return true if successful, false otherwise.
**/
static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
@@ -63,6 +63,9 @@ namespace app {
return false;
}
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+
/* Get input shape. Dimensions of the tensor should have been verified by
* the callee. */
TfLiteIntArray* inputShape = model.GetInputShape(0);
@@ -78,19 +81,19 @@ namespace app {
const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
/* Set up pre and post-processing objects. */
- ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures,
- inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride);
+ AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures,
+ inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+ mfccFrameLen, mfccFrameStride);
std::vector<ClassificationResult> singleInfResult;
- const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
- ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"),
- model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"),
+ const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen);
+ AsrPostProcess postProcess = AsrPostProcess(
+ outputTensor, ctx.Get<AsrClassifier&>("classifier"),
+ ctx.Get<std::vector<std::string>&>("labels"),
singleInfResult, outputCtxLen,
Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
);
- UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model);
-
/* Loop to process audio clips. */
do {
hal_lcd_clear(COLOR_BLACK);
@@ -147,16 +150,20 @@ namespace app {
static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
/* Run the pre-processing, inference and post-processing. */
- runner.PreProcess(inferenceWindow, inferenceWindowLen);
+ if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) {
+ printf_err("Pre-processing failed.");
+ return false;
+ }
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
+ /* Post processing needs to know if we are on the last audio window. */
postProcess.m_lastIteration = !audioDataSlider.HasNext();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -166,7 +173,6 @@ namespace app {
audioDataSlider.Index(), scoreThreshold));
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
armDumpTensor(outputTensor,
outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */