From c291144b7f08c21d08cdaf79cc64dc420ca70070 Mon Sep 17 00:00:00 2001 From: Richard Burton Date: Fri, 22 Apr 2022 09:08:21 +0100 Subject: MLECO-3077: Add ASR use case API * Minor adjustments to doc strings in KWS * Remove unused score threshold in KWS Signed-off-by: Richard Burton Change-Id: Ie1c5bf6f7bdbebb853b6a10cb7ba1c4a1d9a76c9 --- .../use_case/asr/include/Wav2LetterPostprocess.hpp | 115 ++++++++++----------- 1 file changed, 56 insertions(+), 59 deletions(-) (limited to 'source/use_case/asr/include/Wav2LetterPostprocess.hpp') diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp index 29eb548..45defa5 100644 --- a/source/use_case/asr/include/Wav2LetterPostprocess.hpp +++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,93 +17,90 @@ #ifndef ASR_WAV2LETTER_POSTPROCESS_HPP #define ASR_WAV2LETTER_POSTPROCESS_HPP -#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */ +#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */ +#include "BaseProcessing.hpp" +#include "AsrClassifier.hpp" +#include "AsrResult.hpp" #include "log_macros.h" namespace arm { namespace app { -namespace audio { -namespace asr { /** * @brief Helper class to manage tensor post-processing for "wav2letter" * output. */ - class Postprocess { + class ASRPostProcess : public BasePostProcess { public: + bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */ + /** - * @brief Constructor - * @param[in] contextLen Left and right context length for - * output tensor. - * @param[in] innerLen This is the length of the section - * between left and right context. - * @param[in] blankTokenIdx Blank token index. + * @brief Constructor + * @param[in] outputTensor Pointer to the output Tensor. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] result Vector of classification results to store decoded outputs. + * @param[in] outputContextLen Left/right context length for output tensor. + * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes. + * @param[in] reductionAxis The axis that the logits of each time step is on. **/ - Postprocess(uint32_t contextLen, - uint32_t innerLen, - uint32_t blankTokenIdx); - - Postprocess() = delete; - ~Postprocess() = default; + ASRPostProcess(AsrClassifier& classifier, TfLiteTensor* outputTensor, + const std::vector& labels, asr::ResultVec& result, + uint32_t outputContextLen, + uint32_t blankTokenIdx, uint32_t reductionAxis); /** - * @brief Erases the required part of the tensor based - * on context lengths set up during initialisation. - * @param[in] tensor Pointer to the tensor. - * @param[in] axisIdx Index of the axis on which erase is - * performed. - * @param[in] lastIteration Flag to signal this is the - * last iteration in which case - * the right context is preserved. - * @return true if successful, false otherwise. - */ - bool Invoke(TfLiteTensor* tensor, - uint32_t axisIdx, - bool lastIteration = false); + * @brief Should perform post-processing of the result of inference then + * populate ASR result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + + /** @brief Gets the output inner length for post-processing. */ + static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen); + + /** @brief Gets the output context length (left/right) for post-processing. */ + static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen); + + /** @brief Gets the number of feature vectors to be computed. */ + static uint32_t GetNumFeatureVectors(const Model& model); private: - uint32_t m_contextLen; /* lengths of left and right contexts. */ - uint32_t m_innerLen; /* Length of inner context. */ - uint32_t m_totalLen; /* Total length of the required axis. */ - uint32_t m_countIterations; /* Current number of iterations. */ - uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ + AsrClassifier& m_classifier; /* ASR Classifier object. */ + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + const std::vector& m_labels; /* ASR Labels. */ + asr::ResultVec & m_results; /* Results vector for a single inference. */ + uint32_t m_outputContextLen; /* lengths of left/right contexts for output. */ + uint32_t m_outputInnerLen; /* Length of output inner context. */ + uint32_t m_totalLen; /* Total length of the required axis. */ + uint32_t m_countIterations; /* Current number of iterations. */ + uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ + uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */ + /** - * @brief Checks if the tensor and axis index are valid - * inputs to the object - based on how it has been - * initialised. - * @return true if valid, false otherwise. + * @brief Checks if the tensor and axis index are valid + * inputs to the object - based on how it has been initialised. + * @return true if valid, false otherwise. */ bool IsInputValid(TfLiteTensor* tensor, - const uint32_t axisIdx) const; + uint32_t axisIdx) const; /** - * @brief Gets the tensor data element size in bytes based - * on the tensor type. - * @return Size in bytes, 0 if not supported. + * @brief Gets the tensor data element size in bytes based + * on the tensor type. + * @return Size in bytes, 0 if not supported. */ static uint32_t GetTensorElementSize(TfLiteTensor* tensor); /** - * @brief Erases sections from the data assuming row-wise - * arrangement along the context axis. - * @return true if successful, false otherwise. + * @brief Erases sections from the data assuming row-wise + * arrangement along the context axis. + * @return true if successful, false otherwise. */ bool EraseSectionsRowWise(uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration); - - /** - * @brief Erases sections from the data assuming col-wise - * arrangement along the context axis. - * @return true if successful, false otherwise. - */ - static bool EraseSectionsColWise(const uint8_t* ptrData, - const uint32_t strideSzBytes, - const bool lastIteration); + uint32_t strideSzBytes, + bool lastIteration); }; -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */ -- cgit v1.2.1