Diffstat (limited to 'source/use_case/asr/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/asr/src/UseCaseHandler.cc | 38
1 file changed, 22 insertions, 16 deletions
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc
index 7fe959b..850bdc2 100644
--- a/source/use_case/asr/src/UseCaseHandler.cc
+++ b/source/use_case/asr/src/UseCaseHandler.cc
@@ -33,9 +33,9 @@ namespace arm {
 namespace app {
 
     /**
-     * @brief Presents ASR inference results.
-     * @param[in] results Vector of ASR classification results to be displayed.
-     * @return true if successful, false otherwise.
+     * @brief Presents ASR inference results.
+     * @param[in] results Vector of ASR classification results to be displayed.
+     * @return true if successful, false otherwise.
      **/
     static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
 
@@ -63,6 +63,9 @@ namespace app {
             return false;
         }
 
+        TfLiteTensor* inputTensor = model.GetInputTensor(0);
+        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+
         /* Get input shape. Dimensions of the tensor should have been verified by
          * the callee. */
         TfLiteIntArray* inputShape = model.GetInputShape(0);
@@ -78,19 +81,19 @@
         const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
 
         /* Set up pre and post-processing objects. */
-        ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures,
-                inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride);
+        AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures,
+                                                 inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
+                                                 mfccFrameLen, mfccFrameStride);
 
         std::vector<ClassificationResult> singleInfResult;
-        const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen);
-        ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"),
-                model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"),
+        const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen);
+        AsrPostProcess postProcess = AsrPostProcess(
+                outputTensor, ctx.Get<AsrClassifier&>("classifier"),
+                ctx.Get<std::vector<std::string>&>("labels"),
                 singleInfResult, outputCtxLen, Wav2LetterModel::ms_blankTokenIdx,
                 Wav2LetterModel::ms_outputRowsIdx
                 );
 
-        UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model);
-
         /* Loop to process audio clips. */
         do {
             hal_lcd_clear(COLOR_BLACK);
@@ -147,16 +150,20 @@
                     static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
 
             /* Run the pre-processing, inference and post-processing. */
-            runner.PreProcess(inferenceWindow, inferenceWindowLen);
+            if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) {
+                printf_err("Pre-processing failed.");
+                return false;
+            }
 
-            profiler.StartProfiling("Inference");
-            if (!runner.RunInference()) {
+            if (!RunInference(model, profiler)) {
+                printf_err("Inference failed.");
                 return false;
             }
-            profiler.StopProfiling();
 
+            /* Post processing needs to know if we are on the last audio window. */
             postProcess.m_lastIteration = !audioDataSlider.HasNext();
-            if (!runner.PostProcess()) {
+            if (!postProcess.DoPostProcess()) {
+                printf_err("Post-processing failed.");
                 return false;
             }
 
@@ -166,7 +173,6 @@
                     audioDataSlider.Index(), scoreThreshold));
 
 #if VERIFY_TEST_OUTPUT
-            TfLiteTensor* outputTensor = model.GetOutputTensor(0);
             armDumpTensor(outputTensor, outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
 #endif /* VERIFY_TEST_OUTPUT */
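
For reviewers skimming the change: the patch replaces the UseCaseRunner wrapper with explicit per-window calls in the handler, pre-processing, inference, then post-processing, each checked and reported on failure, with the post-process object told when the final audio window is being handled. The following is a minimal, self-contained sketch of that control flow only; FakeModel, FakePreProcess, FakePostProcess and ProcessWindow are hypothetical stand-ins for illustration and are not the AsrPreProcess/AsrPostProcess/RunInference APIs used by the real handler.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

/* Simplified stand-ins; the real code uses the model, AsrPreProcess and AsrPostProcess objects. */
struct FakeModel {
    bool Run() { return true; }                  /* pretend inference always succeeds */
};

struct FakePreProcess {
    bool DoPreProcess(const int16_t* /*audio*/, size_t /*len*/) { return true; }
};

struct FakePostProcess {
    bool m_lastIteration = false;                /* mirrors postProcess.m_lastIteration in the patch */
    bool DoPostProcess() { return true; }
};

/* Per-window sequence introduced by the patch: each stage is checked and a
 * failure aborts the handler with an error message (printf_err in the real code). */
static bool ProcessWindow(FakePreProcess& pre, FakeModel& model, FakePostProcess& post,
                          const int16_t* window, size_t windowLen, bool isLastWindow)
{
    if (!pre.DoPreProcess(window, windowLen)) {
        std::fprintf(stderr, "Pre-processing failed.\n");
        return false;
    }
    if (!model.Run()) {                          /* the patch calls RunInference(model, profiler) here */
        std::fprintf(stderr, "Inference failed.\n");
        return false;
    }
    post.m_lastIteration = isLastWindow;         /* post-processing needs to know about the last window */
    if (!post.DoPostProcess()) {
        std::fprintf(stderr, "Post-processing failed.\n");
        return false;
    }
    return true;
}

int main()
{
    FakeModel model;
    FakePreProcess pre;
    FakePostProcess post;
    std::vector<int16_t> audio(16000, 0);        /* one dummy 1-second window at 16 kHz */

    const bool ok = ProcessWindow(pre, model, post, audio.data(), audio.size(), true);
    std::printf("window processed: %s\n", ok ? "ok" : "failed");
    return ok ? 0 : 1;
}

As the diff itself shows, dropping UseCaseRunner means the handler decides what to do when a stage fails, so each stage can emit its own printf_err message before the handler returns false.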