summary | refs | log | tree | commit | diff
path: root/source/use_case/asr/src/MainLoop.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/asr/src/MainLoop.cc')
-rw-r--r-- source/use_case/asr/src/MainLoop.cc 85
1 files changed, 1 insertions, 84 deletions
diff --git a/source/use_case/asr/src/MainLoop.cc b/source/use_case/asr/src/MainLoop.cc
index 51b0b18..a1a9540 100644
--- a/source/use_case/asr/src/MainLoop.cc
+++ b/source/use_case/asr/src/MainLoop.cc
@@ -14,15 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "hal.h" /* Brings in platform definitions. */
#include "Labels.hpp" /* For label strings. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
#include "Wav2LetterModel.hpp" /* Model class for running inference. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
#include "AsrClassifier.hpp" /* Classifier. */
#include "InputFiles.hpp" /* Generated audio clip header. */
-#include "Wav2LetterPreprocess.hpp" /* Pre-processing class. */
-#include "Wav2LetterPostprocess.hpp" /* Post-processing class. */
#include "log_macros.h"
enum opcodes
@@ -48,23 +45,9 @@ static void DisplayMenu()
fflush(stdout);
}
-/** @brief Verify input and output tensor are of certain min dimensions. */
+/** @brief Verify input and output tensor are of certain min dimensions. */
static bool VerifyTensorDimensions(const arm::app::Model& model);
-/** @brief Gets the number of MFCC features for a single window. */
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model);
-
-/** @brief Gets the number of MFCC feature vectors to be computed. */
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model);
-
-/** @brief Gets the output context length (left and right) for post-processing. */
-static uint32_t GetOutputContextLen(const arm::app::Model& model,
- uint32_t inputCtxLen);
-
-/** @brief Gets the output inner length for post-processing. */
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- uint32_t outputCtxLen);
-
void main_loop()
{
arm::app::Wav2LetterModel model; /* Model wrapper object. */
@@ -78,21 +61,6 @@ void main_loop()
return;
}
- /* Initialise pre-processing. */
- arm::app::audio::asr::Preprocess prep(
- GetNumMfccFeatures(model),
- g_FrameLength,
- g_FrameStride,
- GetNumMfccFeatureVectors(model));
-
- /* Initialise post-processing. */
- const uint32_t outputCtxLen = GetOutputContextLen(model, g_ctxLen);
- const uint32_t blankTokenIdx = 28;
- arm::app::audio::asr::Postprocess postp(
- outputCtxLen,
- GetOutputInnerLen(model, outputCtxLen),
- blankTokenIdx);
-
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
std::vector <std::string> labels;
@@ -109,8 +77,6 @@ void main_loop()
caseContext.Set<uint32_t>("ctxLen", g_ctxLen); /* Left and right context length (MFCC feat vectors). */
caseContext.Set<const std::vector <std::string>&>("labels", labels);
caseContext.Set<arm::app::AsrClassifier&>("classifier", classifier);
- caseContext.Set<arm::app::audio::asr::Preprocess&>("preprocess", prep);
- caseContext.Set<arm::app::audio::asr::Postprocess&>("postprocess", postp);
bool executionSuccessful = true;
constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false;
@@ -184,52 +150,3 @@ static bool VerifyTensorDimensions(const arm::app::Model& model)
return true;
}
-
-static uint32_t GetNumMfccFeatures(const arm::app::Model& model)
-{
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
- if (0 != inputCols % 3) {
- printf_err("Number of input columns is not a multiple of 3\n");
- }
- return std::max(inputCols/3, 0);
-}
-
-static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model)
-{
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
- const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
- return std::max(inputRows, 0);
-}
-
-static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen)
-{
- const uint32_t inputRows = GetNumMfccFeatureVectors(model);
- const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
-
- /* Check to make sure that the input tensor supports the above
- * context and inner lengths. */
- if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) {
- printf_err("Input rows not compatible with ctx of %" PRIu32 "\n",
- inputCtxLen);
- return 0;
- }
-
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
-
- const float tensorColRatio = static_cast<float>(inputRows)/
- static_cast<float>(outputRows);
-
- return std::round(static_cast<float>(inputCtxLen)/tensorColRatio);
-}
-
-static uint32_t GetOutputInnerLen(const arm::app::Model& model,
- const uint32_t outputCtxLen)
-{
- constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx;
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0);
- return (outputRows - (2 * outputCtxLen));
-}