diff options
Diffstat (limited to 'source/use_case/asr/src/MainLoop.cc')
-rw-r--r-- | source/use_case/asr/src/MainLoop.cc | 85 |
1 file changed, 1 insertion, 84 deletions
diff --git a/source/use_case/asr/src/MainLoop.cc b/source/use_case/asr/src/MainLoop.cc index 51b0b18..a1a9540 100644 --- a/source/use_case/asr/src/MainLoop.cc +++ b/source/use_case/asr/src/MainLoop.cc @@ -14,15 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions. */ #include "Labels.hpp" /* For label strings. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ #include "Wav2LetterModel.hpp" /* Model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "AsrClassifier.hpp" /* Classifier. */ #include "InputFiles.hpp" /* Generated audio clip header. */ -#include "Wav2LetterPreprocess.hpp" /* Pre-processing class. */ -#include "Wav2LetterPostprocess.hpp" /* Post-processing class. */ #include "log_macros.h" enum opcodes @@ -48,23 +45,9 @@ static void DisplayMenu() fflush(stdout); } -/** @brief Verify input and output tensor are of certain min dimensions. */ +/** @brief Verify input and output tensor are of certain min dimensions. */ static bool VerifyTensorDimensions(const arm::app::Model& model); -/** @brief Gets the number of MFCC features for a single window. */ -static uint32_t GetNumMfccFeatures(const arm::app::Model& model); - -/** @brief Gets the number of MFCC feature vectors to be computed. */ -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); - -/** @brief Gets the output context length (left and right) for post-processing. */ -static uint32_t GetOutputContextLen(const arm::app::Model& model, - uint32_t inputCtxLen); - -/** @brief Gets the output inner length for post-processing. */ -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - uint32_t outputCtxLen); - void main_loop() { arm::app::Wav2LetterModel model; /* Model wrapper object. */ @@ -78,21 +61,6 @@ void main_loop() return; } - /* Initialise pre-processing. 
*/ - arm::app::audio::asr::Preprocess prep( - GetNumMfccFeatures(model), - g_FrameLength, - g_FrameStride, - GetNumMfccFeatureVectors(model)); - - /* Initialise post-processing. */ - const uint32_t outputCtxLen = GetOutputContextLen(model, g_ctxLen); - const uint32_t blankTokenIdx = 28; - arm::app::audio::asr::Postprocess postp( - outputCtxLen, - GetOutputInnerLen(model, outputCtxLen), - blankTokenIdx); - /* Instantiate application context. */ arm::app::ApplicationContext caseContext; std::vector <std::string> labels; @@ -109,8 +77,6 @@ void main_loop() caseContext.Set<uint32_t>("ctxLen", g_ctxLen); /* Left and right context length (MFCC feat vectors). */ caseContext.Set<const std::vector <std::string>&>("labels", labels); caseContext.Set<arm::app::AsrClassifier&>("classifier", classifier); - caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep); - caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp); bool executionSuccessful = true; constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? 
true : false; @@ -184,52 +150,3 @@ static bool VerifyTensorDimensions(const arm::app::Model& model) return true; } - -static uint32_t GetNumMfccFeatures(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; - if (0 != inputCols % 3) { - printf_err("Number of input columns is not a multiple of 3\n"); - } - return std::max(inputCols/3, 0); -} - -static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; - return std::max(inputRows, 0); -} - -static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) -{ - const uint32_t inputRows = GetNumMfccFeatureVectors(model); - const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - - /* Check to make sure that the input tensor supports the above - * context and inner lengths. 
*/ - if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { - printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", - inputCtxLen); - return 0; - } - - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - - const float tensorColRatio = static_cast<float>(inputRows)/ - static_cast<float>(outputRows); - - return std::round(static_cast<float>(inputCtxLen)/tensorColRatio); -} - -static uint32_t GetOutputInnerLen(const arm::app::Model& model, - const uint32_t outputCtxLen) -{ - constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); - return (outputRows - (2 * outputCtxLen)); -} |