diff options
Diffstat (limited to 'source/use_case/kws_asr/include')
-rw-r--r-- | source/use_case/kws_asr/include/AsrClassifier.hpp | 66 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/AsrResult.hpp | 63 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/KwsProcessing.hpp | 138 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/KwsResult.hpp | 63 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp | 51 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/MicroNetKwsModel.hpp | 66 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/OutputDecode.hpp | 40 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/Wav2LetterMfcc.hpp | 113 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/Wav2LetterModel.hpp | 71 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp | 108 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp | 182 |
11 files changed, 0 insertions, 961 deletions
diff --git a/source/use_case/kws_asr/include/AsrClassifier.hpp b/source/use_case/kws_asr/include/AsrClassifier.hpp deleted file mode 100644 index 6ab9685..0000000 --- a/source/use_case/kws_asr/include/AsrClassifier.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef ASR_CLASSIFIER_HPP -#define ASR_CLASSIFIER_HPP - -#include "Classifier.hpp" - -namespace arm { -namespace app { - - class AsrClassifier : public Classifier { - public: - /** - * @brief Gets the top N classification results from the - * output vector. - * @param[in] outputTensor Inference output tensor from an NN model. - * @param[out] vecResults A vector of classification results - * populated by this function. - * @param[in] labels Labels vector to match classified classes - * @param[in] topNCount Number of top classifications to pick. - * @param[in] use_softmax Whether softmax scaling should be applied to model output. - * @return true if successful, false otherwise. - **/ - bool GetClassificationResults( - TfLiteTensor* outputTensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount, - bool use_softmax = false) override; - - private: - - /** - * @brief Utility function that gets the top 1 classification results from the - * output tensor (vector of vector). - * @param[in] tensor Inference output tensor from an NN model. - * @param[out] vecResults A vector of classification results - * populated by this function. - * @param[in] labels Labels vector to match classified classes. - * @param[in] scale Quantization scale. - * @param[in] zeroPoint Quantization zero point. - * @return true if successful, false otherwise. - **/ - template<typename T> - bool GetTopResults(TfLiteTensor* tensor, - std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, double scale, double zeroPoint); - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* ASR_CLASSIFIER_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/AsrResult.hpp b/source/use_case/kws_asr/include/AsrResult.hpp deleted file mode 100644 index 25fa9e8..0000000 --- a/source/use_case/kws_asr/include/AsrResult.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef ASR_RESULT_HPP -#define ASR_RESULT_HPP - -#include "ClassificationResult.hpp" - -#include <vector> - -namespace arm { -namespace app { -namespace asr { - - using ResultVec = std::vector<arm::app::ClassificationResult>; - - /* Structure for holding asr result. */ - class AsrResult { - - public: - ResultVec m_resultVec; /* Container for "thresholded" classification results. */ - float m_timeStamp; /* Audio timestamp for this result. */ - uint32_t m_inferenceNumber; /* Corresponding inference number. */ - float m_threshold; /* Threshold value for `m_resultVec` */ - - AsrResult() = delete; - AsrResult(ResultVec& resultVec, - const float timestamp, - const uint32_t inferenceIdx, - const float scoreThreshold) { - - this->m_threshold = scoreThreshold; - this->m_timeStamp = timestamp; - this->m_inferenceNumber = inferenceIdx; - - this->m_resultVec = ResultVec(); - for (auto& i : resultVec) { - if (i.m_normalisedVal >= this->m_threshold) { - this->m_resultVec.emplace_back(i); - } - } - } - ~AsrResult() = default; - }; - -} /* namespace asr */ -} /* namespace app */ -} /* namespace arm */ - -#endif /* ASR_RESULT_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/KwsProcessing.hpp b/source/use_case/kws_asr/include/KwsProcessing.hpp deleted file mode 100644 index d3de3b3..0000000 --- a/source/use_case/kws_asr/include/KwsProcessing.hpp +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2022 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_PROCESSING_HPP -#define KWS_PROCESSING_HPP - -#include <AudioUtils.hpp> -#include "BaseProcessing.hpp" -#include "Model.hpp" -#include "Classifier.hpp" -#include "MicroNetKwsMfcc.hpp" - -#include <functional> - -namespace arm { -namespace app { - - /** - * @brief Pre-processing class for Keyword Spotting use case. - * Implements methods declared by BasePreProcess and anything else needed - * to populate input tensors ready for inference. - */ - class KwsPreProcess : public BasePreProcess { - - public: - /** - * @brief Constructor - * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. - * @param[in] numFeatures How many MFCC features to use. - * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated - * for an inference. - * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when - * sliding a window through the audio sample. - * @param[in] mfccFrameStride Number of audio samples between consecutive windows. - **/ - explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames, - int mfccFrameLength, int mfccFrameStride); - - /** - * @brief Should perform pre-processing of 'raw' input audio data and load it into - * TFLite Micro input tensors ready for inference. - * @param[in] input Pointer to the data that pre-processing will work on. - * @param[in] inputSize Size of the input data. - * @return true if successful, false otherwise. - **/ - bool DoPreProcess(const void* input, size_t inputSize) override; - - size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */ - size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */ - size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */ - - private: - TfLiteTensor* m_inputTensor; /* Model input tensor. */ - const int m_mfccFrameLength; - const int m_mfccFrameStride; - const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */ - - audio::MicroNetKwsMFCC m_mfcc; - audio::SlidingWindow<const int16_t> m_mfccSlidingWindow; - size_t m_numMfccVectorsInAudioStride; - size_t m_numReusedMfccVectors; - std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator; - - /** - * @brief Returns a function to perform feature calculation and populates input tensor data with - * MFCC data. - * - * Input tensor data type check is performed to choose correct MFCC feature data type. - * If tensor has an integer data type then original features are quantised. - * - * Warning: MFCC calculator provided as input must have the same life scope as returned function. - * - * @param[in] mfcc MFCC feature calculator. - * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). - * @return Function to be called providing audio sample and sliding window index. - */ - std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, - TfLiteTensor* inputTensor, - size_t cacheSize); - - template<class T> - std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function<std::vector<T> (std::vector<int16_t>& )> compute); - }; - - /** - * @brief Post-processing class for Keyword Spotting use case. - * Implements methods declared by BasePostProcess and anything else needed - * to populate result vector. - */ - class KwsPostProcess : public BasePostProcess { - - private: - TfLiteTensor* m_outputTensor; /* Model output tensor. */ - Classifier& m_kwsClassifier; /* KWS Classifier object. */ - const std::vector<std::string>& m_labels; /* KWS Labels. */ - std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */ - - public: - /** - * @brief Constructor - * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in/out] results Vector of classification results to store decoded outputs. - **/ - KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, - const std::vector<std::string>& labels, - std::vector<ClassificationResult>& results); - - /** - * @brief Should perform post-processing of the result of inference then - * populate KWS result data for any later use. - * @return true if successful, false otherwise. - **/ - bool DoPostProcess() override; - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_PROCESSING_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/KwsResult.hpp b/source/use_case/kws_asr/include/KwsResult.hpp deleted file mode 100644 index 45bb790..0000000 --- a/source/use_case/kws_asr/include/KwsResult.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_RESULT_HPP -#define KWS_RESULT_HPP - -#include "ClassificationResult.hpp" - -#include <vector> - -namespace arm { -namespace app { -namespace kws { - - using ResultVec = std::vector < arm::app::ClassificationResult >; - - /* Structure for holding kws result. */ - class KwsResult { - - public: - ResultVec m_resultVec; /* Container for "thresholded" classification results. */ - float m_timeStamp; /* Audio timestamp for this result. */ - uint32_t m_inferenceNumber; /* Corresponding inference number. */ - float m_threshold; /* Threshold value for `m_resultVec.` */ - - KwsResult() = delete; - KwsResult(ResultVec& resultVec, - const float timestamp, - const uint32_t inferenceIdx, - const float scoreThreshold) { - - this->m_threshold = scoreThreshold; - this->m_timeStamp = timestamp; - this->m_inferenceNumber = inferenceIdx; - - this->m_resultVec = ResultVec(); - for (auto & i : resultVec) { - if (i.m_normalisedVal >= this->m_threshold) { - this->m_resultVec.emplace_back(i); - } - } - } - ~KwsResult() = default; - }; - -} /* namespace kws */ -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_RESULT_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp deleted file mode 100644 index af6ba5f..0000000 --- a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2021-2022 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_MICRONET_MFCC_HPP -#define KWS_ASR_MICRONET_MFCC_HPP - -#include "Mfcc.hpp" - -namespace arm { -namespace app { -namespace audio { - - /* Class to provide MicroNet specific MFCC calculation requirements. */ - class MicroNetKwsMFCC : public MFCC { - - public: - static constexpr uint32_t ms_defaultSamplingFreq = 16000; - static constexpr uint32_t ms_defaultNumFbankBins = 40; - static constexpr uint32_t ms_defaultMelLoFreq = 20; - static constexpr uint32_t ms_defaultMelHiFreq = 4000; - static constexpr bool ms_defaultUseHtkMethod = true; - - - explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen) - : MFCC(MfccParams( - ms_defaultSamplingFreq, ms_defaultNumFbankBins, - ms_defaultMelLoFreq, ms_defaultMelHiFreq, - numFeats, frameLen, ms_defaultUseHtkMethod)) - {} - MicroNetKwsMFCC() = delete; - ~MicroNetKwsMFCC() = default; - }; - -} /* namespace audio */ -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_MICRONET_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/MicroNetKwsModel.hpp b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp deleted file mode 100644 index 22cf916..0000000 --- a/source/use_case/kws_asr/include/MicroNetKwsModel.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_MICRONETMODEL_HPP -#define KWS_ASR_MICRONETMODEL_HPP - -#include "Model.hpp" - -namespace arm { -namespace app { -namespace kws { - extern const int g_FrameLength; - extern const int g_FrameStride; - extern const float g_ScoreThreshold; - extern const uint32_t g_NumMfcc; - extern const uint32_t g_NumAudioWins; -} /* namespace kws */ -} /* namespace app */ -} /* namespace arm */ - -namespace arm { -namespace app { - class MicroNetKwsModel : public Model { - public: - /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 1; - static constexpr uint32_t ms_inputColsIdx = 2; - static constexpr uint32_t ms_outputRowsIdx = 2; - static constexpr uint32_t ms_outputColsIdx = 3; - - protected: - /** @brief Gets the reference to op resolver interface class. */ - const tflite::MicroOpResolver& GetOpResolver() override; - - /** @brief Adds operations to the op resolver instance. */ - bool EnlistOperations() override; - - const uint8_t* ModelPointer() override; - - size_t ModelSize() override; - - private: - /* Maximum number of individual operations that can be enlisted. */ - static constexpr int ms_maxOpCnt = 7; - - /* A mutable op resolver instance. */ - tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver; - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_MICRONETMODEL_HPP */ diff --git a/source/use_case/kws_asr/include/OutputDecode.hpp b/source/use_case/kws_asr/include/OutputDecode.hpp deleted file mode 100644 index cea2c33..0000000 --- a/source/use_case/kws_asr/include/OutputDecode.hpp +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_OUTPUT_DECODE_HPP -#define KWS_ASR_OUTPUT_DECODE_HPP - -#include "AsrClassifier.hpp" - -namespace arm { -namespace app { -namespace audio { -namespace asr { - - /** - * @brief Gets the top N classification results from the - * output vector. - * @param[in] vecResults Label output from classifier. - * @return true if successful, false otherwise. - **/ - std::string DecodeOutput(const std::vector<ClassificationResult>& vecResults); - -} /* namespace asr */ -} /* namespace audio */ -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_OUTPUT_DECODE_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/Wav2LetterMfcc.hpp b/source/use_case/kws_asr/include/Wav2LetterMfcc.hpp deleted file mode 100644 index 75d75da..0000000 --- a/source/use_case/kws_asr/include/Wav2LetterMfcc.hpp +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_WAV2LET_MFCC_HPP -#define KWS_ASR_WAV2LET_MFCC_HPP - -#include "Mfcc.hpp" - -namespace arm { -namespace app { -namespace audio { - - /* Class to provide Wav2Letter specific MFCC calculation requirements. */ - class Wav2LetterMFCC : public MFCC { - - public: - static constexpr uint32_t ms_defaultSamplingFreq = 16000; - static constexpr uint32_t ms_defaultNumFbankBins = 128; - static constexpr uint32_t ms_defaultMelLoFreq = 0; - static constexpr uint32_t ms_defaultMelHiFreq = 8000; - static constexpr bool ms_defaultUseHtkMethod = false; - - explicit Wav2LetterMFCC(const size_t numFeats, const size_t frameLen) - : MFCC(MfccParams( - ms_defaultSamplingFreq, ms_defaultNumFbankBins, - ms_defaultMelLoFreq, ms_defaultMelHiFreq, - numFeats, frameLen, ms_defaultUseHtkMethod)) - {} - - Wav2LetterMFCC() = delete; - ~Wav2LetterMFCC() = default; - - protected: - - /** - * @brief Overrides base class implementation of this function. - * @param[in] fftVec Vector populated with FFT magnitudes. - * @param[in] melFilterBank 2D Vector with filter bank weights. - * @param[in] filterBankFilterFirst Vector containing the first indices of filter bank - * to be used for each bin. - * @param[in] filterBankFilterLast Vector containing the last indices of filter bank - * to be used for each bin. - * @param[out] melEnergies Pre-allocated vector of MEL energies to be - * populated. - * @return true if successful, false otherwise. - */ - bool ApplyMelFilterBank( - std::vector<float>& fftVec, - std::vector<std::vector<float>>& melFilterBank, - std::vector<uint32_t>& filterBankFilterFirst, - std::vector<uint32_t>& filterBankFilterLast, - std::vector<float>& melEnergies) override; - - /** - * @brief Override for the base class implementation convert mel - * energies to logarithmic scale. The difference from - * default behaviour is that the power is converted to dB - * and subsequently clamped. - * @param[in,out] melEnergies 1D vector of Mel energies. - **/ - void ConvertToLogarithmicScale( - std::vector<float>& melEnergies) override; - - /** - * @brief Create a matrix used to calculate Discrete Cosine - * Transform. Override for the base class' default - * implementation as the first and last elements - * use a different normaliser. - * @param[in] inputLength Input length of the buffer on which - * DCT will be performed. - * @param[in] coefficientCount Total coefficients per input length. - * @return 1D vector with inputLength x coefficientCount elements - * populated with DCT coefficients. - */ - std::vector<float> CreateDCTMatrix( - int32_t inputLength, - int32_t coefficientCount) override; - - /** - * @brief Given the low and high Mel values, get the normaliser - * for weights to be applied when populating the filter - * bank. Override for the base class implementation. - * @param[in] leftMel Low Mel frequency value. - * @param[in] rightMel High Mel frequency value. - * @param[in] useHTKMethod Bool to signal if HTK method is to be - * used for calculation. - * @return Value to use for normalising. - */ - float GetMelFilterBankNormaliser( - const float& leftMel, - const float& rightMel, - bool useHTKMethod) override; - - }; - -} /* namespace audio */ -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_WAV2LET_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp deleted file mode 100644 index 0e1adc5..0000000 --- a/source/use_case/kws_asr/include/Wav2LetterModel.hpp +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2021-2022 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_WAV2LETTER_MODEL_HPP -#define KWS_ASR_WAV2LETTER_MODEL_HPP - -#include "Model.hpp" - -namespace arm { -namespace app { -namespace asr { - extern const int g_FrameLength; - extern const int g_FrameStride; - extern const float g_ScoreThreshold; - extern const int g_ctxLen; -} /* namespace asr */ -} /* namespace app */ -} /* namespace arm */ - -namespace arm { -namespace app { - - class Wav2LetterModel : public Model { - - public: - /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 1; - static constexpr uint32_t ms_inputColsIdx = 2; - static constexpr uint32_t ms_outputRowsIdx = 2; - static constexpr uint32_t ms_outputColsIdx = 3; - - /* Model specific constants. */ - static constexpr uint32_t ms_blankTokenIdx = 28; - static constexpr uint32_t ms_numMfccFeatures = 13; - - protected: - /** @brief Gets the reference to op resolver interface class. */ - const tflite::MicroOpResolver& GetOpResolver() override; - - /** @brief Adds operations to the op resolver instance. */ - bool EnlistOperations() override; - - const uint8_t* ModelPointer() override; - - size_t ModelSize() override; - - private: - /* Maximum number of individual operations that can be enlisted. */ - static constexpr int ms_maxOpCnt = 5; - - /* A mutable op resolver instance. */ - tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver; - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_WAV2LETTER_MODEL_HPP */ diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp deleted file mode 100644 index d1bc9a2..0000000 --- a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) 2021-2022 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_WAV2LETTER_POSTPROCESS_HPP -#define KWS_ASR_WAV2LETTER_POSTPROCESS_HPP - -#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */ -#include "BaseProcessing.hpp" -#include "AsrClassifier.hpp" -#include "AsrResult.hpp" -#include "log_macros.h" - -namespace arm { -namespace app { - - /** - * @brief Helper class to manage tensor post-processing for "wav2letter" - * output. - */ - class AsrPostProcess : public BasePostProcess { - public: - bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */ - - /** - * @brief Constructor - * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. - * @param[in] classifier Object used to get top N results from classification. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in/out] result Vector of classification results to store decoded outputs. - * @param[in] outputContextLen Left/right context length for output tensor. - * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes. - * @param[in] reductionAxis The axis that the logits of each time step is on. - **/ - AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, - const std::vector<std::string>& labels, asr::ResultVec& result, - uint32_t outputContextLen, - uint32_t blankTokenIdx, uint32_t reductionAxis); - - /** - * @brief Should perform post-processing of the result of inference then - * populate ASR result data for any later use. - * @return true if successful, false otherwise. - **/ - bool DoPostProcess() override; - - /** @brief Gets the output inner length for post-processing. */ - static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen); - - /** @brief Gets the output context length (left/right) for post-processing. */ - static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen); - - /** @brief Gets the number of feature vectors to be computed. */ - static uint32_t GetNumFeatureVectors(const Model& model); - - private: - AsrClassifier& m_classifier; /* ASR Classifier object. */ - TfLiteTensor* m_outputTensor; /* Model output tensor. */ - const std::vector<std::string>& m_labels; /* ASR Labels. */ - asr::ResultVec & m_results; /* Results vector for a single inference. */ - uint32_t m_outputContextLen; /* lengths of left/right contexts for output. */ - uint32_t m_outputInnerLen; /* Length of output inner context. */ - uint32_t m_totalLen; /* Total length of the required axis. */ - uint32_t m_countIterations; /* Current number of iterations. */ - uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ - uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */ - - /** - * @brief Checks if the tensor and axis index are valid - * inputs to the object - based on how it has been initialised. - * @return true if valid, false otherwise. - */ - bool IsInputValid(TfLiteTensor* tensor, - uint32_t axisIdx) const; - - /** - * @brief Gets the tensor data element size in bytes based - * on the tensor type. - * @return Size in bytes, 0 if not supported. - */ - static uint32_t GetTensorElementSize(TfLiteTensor* tensor); - - /** - * @brief Erases sections from the data assuming row-wise - * arrangement along the context axis. - * @return true if successful, false otherwise. - */ - bool EraseSectionsRowWise(uint8_t* ptrData, - uint32_t strideSzBytes, - bool lastIteration); - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_WAV2LETTER_POSTPROCESS_HPP */
\ No newline at end of file diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp deleted file mode 100644 index 1224c23..0000000 --- a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Copyright (c) 2021-2022 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_WAV2LETTER_PREPROCESS_HPP -#define KWS_ASR_WAV2LETTER_PREPROCESS_HPP - -#include "Wav2LetterModel.hpp" -#include "Wav2LetterMfcc.hpp" -#include "AudioUtils.hpp" -#include "DataStructures.hpp" -#include "BaseProcessing.hpp" -#include "log_macros.h" - -namespace arm { -namespace app { - - /* Class to facilitate pre-processing calculation for Wav2Letter model - * for ASR. */ - using AudioWindow = audio::SlidingWindow<const int16_t>; - - class AsrPreProcess : public BasePreProcess { - public: - /** - * @brief Constructor. - * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. - * @param[in] numMfccFeatures Number of MFCC features per window. - * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated - * for an inference. - * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window. - * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window. - */ - AsrPreProcess(TfLiteTensor* inputTensor, - uint32_t numMfccFeatures, - uint32_t numFeatureFrames, - uint32_t mfccWindowLen, - uint32_t mfccWindowStride); - - /** - * @brief Calculates the features required from audio data. This - * includes MFCC, first and second order deltas, - * normalisation and finally, quantisation. The tensor is - * populated with features from a given window placed along - * in a single row. - * @param[in] audioData Pointer to the first element of audio data. - * @param[in] audioDataLen Number of elements in the audio data. - * @return true if successful, false in case of error. - */ - bool DoPreProcess(const void* audioData, size_t audioDataLen) override; - - protected: - /** - * @brief Computes the first and second order deltas for the - * MFCC buffers - they are assumed to be populated. - * - * @param[in] mfcc MFCC buffers. - * @param[out] delta1 Result of the first diff computation. - * @param[out] delta2 Result of the second diff computation. - * @return true if successful, false otherwise. - */ - static bool ComputeDeltas(Array2d<float>& mfcc, - Array2d<float>& delta1, - Array2d<float>& delta2); - - /** - * @brief Given a 2D vector of floats, rescale it to have mean of 0 and - * standard deviation of 1. - * @param[in,out] vec Vector of vector of floats. - */ - static void StandardizeVecF32(Array2d<float>& vec); - - /** - * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1. - */ - void Standarize(); - - /** - * @brief Given the quantisation and data type limits, computes - * the quantised values of a floating point input data. - * @param[in] elem Element to be quantised. - * @param[in] quantScale Scale. - * @param[in] quantOffset Offset. - * @param[in] minVal Numerical limit - minimum. - * @param[in] maxVal Numerical limit - maximum. - * @return Floating point quantised value. - */ - static float GetQuantElem( - float elem, - float quantScale, - int quantOffset, - float minVal, - float maxVal); - - /** - * @brief Quantises the MFCC and delta buffers, and places them - * in the output buffer. While doing so, it transposes - * the data. Reason: Buffers in this class are arranged - * for "time" axis to be row major. Primary reason for - * this being the convolution speed up (as we can use - * contiguous memory). The output, however, requires the - * time axis to be in column major arrangement. - * @param[in] outputBuf Pointer to the output buffer. - * @param[in] outputBufSz Output buffer's size. - * @param[in] quantScale Quantisation scale. - * @param[in] quantOffset Quantisation offset. - */ - template <typename T> - bool Quantise( - T* outputBuf, - const uint32_t outputBufSz, - const float quantScale, - const int quantOffset) - { - /* Check the output size will fit everything. */ - if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) { - printf_err("Tensor size too small for features\n"); - return false; - } - - /* Populate. */ - T* outputBufMfcc = outputBuf; - T* outputBufD1 = outputBuf + this->m_numMfccFeats; - T* outputBufD2 = outputBufD1 + this->m_numMfccFeats; - const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */ - - const float minVal = std::numeric_limits<T>::min(); - const float maxVal = std::numeric_limits<T>::max(); - - /* Need to transpose while copying and concatenating the tensor. */ - for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) { - for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) { - *outputBufMfcc++ = static_cast<T>(AsrPreProcess::GetQuantElem( - this->m_mfccBuf(i, j), quantScale, - quantOffset, minVal, maxVal)); - *outputBufD1++ = static_cast<T>(AsrPreProcess::GetQuantElem( - this->m_delta1Buf(i, j), quantScale, - quantOffset, minVal, maxVal)); - *outputBufD2++ = static_cast<T>(AsrPreProcess::GetQuantElem( - this->m_delta2Buf(i, j), quantScale, - quantOffset, minVal, maxVal)); - } - outputBufMfcc += ptrIncr; - outputBufD1 += ptrIncr; - outputBufD2 += ptrIncr; - } - - return true; - } - - private: - audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */ - TfLiteTensor* m_inputTensor; /* Model input tensor. */ - - /* Actual buffers to be populated. */ - Array2d<float> m_mfccBuf; /* Contiguous buffer 1D: MFCC */ - Array2d<float> m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ - Array2d<float> m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ - - uint32_t m_mfccWindowLen; /* Window length for MFCC. */ - uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */ - uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ - uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */ - AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */ - - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_WAV2LETTER_PREPROCESS_HPP */
\ No newline at end of file |