diff options
author | Richard Burton <richard.burton@arm.com> | 2022-04-22 16:14:57 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-04-22 16:14:57 +0100 |
commit | b40ecf8522052809d2351677a96195d69e4d0c16 (patch) | |
tree | 8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/asr/src/UseCaseHandler.cc | |
parent | c291144b7f08c21d08cdaf79cc64dc420ca70070 (diff) | |
download | ml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz |
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments
Removed the inference runner code as it wasn't used
Fixes to doc strings
Consistent naming e.g. Asr/Kws instead of ASR/KWS
Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'source/use_case/asr/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/asr/src/UseCaseHandler.cc | 38 |
1 files changed, 22 insertions, 16 deletions
diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc index 7fe959b..850bdc2 100644 --- a/source/use_case/asr/src/UseCaseHandler.cc +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -33,9 +33,9 @@ namespace arm { namespace app { /** - * @brief Presents ASR inference results. - * @param[in] results Vector of ASR classification results to be displayed. - * @return true if successful, false otherwise. + * @brief Presents ASR inference results. + * @param[in] results Vector of ASR classification results to be displayed. + * @return true if successful, false otherwise. **/ static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results); @@ -63,6 +63,9 @@ namespace app { return false; } + TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + /* Get input shape. Dimensions of the tensor should have been verified by * the callee. */ TfLiteIntArray* inputShape = model.GetInputShape(0); @@ -78,19 +81,19 @@ namespace app { const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq); /* Set up pre and post-processing objects. */ - ASRPreProcess preProcess = ASRPreProcess(model.GetInputTensor(0), Wav2LetterModel::ms_numMfccFeatures, - inputShape->data[Wav2LetterModel::ms_inputRowsIdx], mfccFrameLen, mfccFrameStride); + AsrPreProcess preProcess = AsrPreProcess(inputTensor, Wav2LetterModel::ms_numMfccFeatures, + inputShape->data[Wav2LetterModel::ms_inputRowsIdx], + mfccFrameLen, mfccFrameStride); std::vector<ClassificationResult> singleInfResult; - const uint32_t outputCtxLen = ASRPostProcess::GetOutputContextLen(model, inputCtxLen); - ASRPostProcess postProcess = ASRPostProcess(ctx.Get<AsrClassifier&>("classifier"), - model.GetOutputTensor(0), ctx.Get<std::vector<std::string>&>("labels"), + const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen); + AsrPostProcess postProcess = AsrPostProcess( + outputTensor, ctx.Get<AsrClassifier&>("classifier"), + ctx.Get<std::vector<std::string>&>("labels"), singleInfResult, outputCtxLen, Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx ); - UseCaseRunner runner = UseCaseRunner(&preProcess, &postProcess, &model); - /* Loop to process audio clips. */ do { hal_lcd_clear(COLOR_BLACK); @@ -147,16 +150,20 @@ namespace app { static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); /* Run the pre-processing, inference and post-processing. */ - runner.PreProcess(inferenceWindow, inferenceWindowLen); + if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) { + printf_err("Pre-processing failed."); + return false; + } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); + /* Post processing needs to know if we are on the last audio window. */ postProcess.m_lastIteration = !audioDataSlider.HasNext(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } @@ -166,7 +173,6 @@ namespace app { audioDataSlider.Index(), scoreThreshold)); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); armDumpTensor(outputTensor, outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]); #endif /* VERIFY_TEST_OUTPUT */ |