From aa4bcb14d0cbee910331545dd2fc086b58c37170 Mon Sep 17 00:00:00 2001 From: Kshitij Sisodia Date: Fri, 6 May 2022 09:13:03 +0100 Subject: MLECO-3183: Refactoring application sources Platform agnostic application sources are moved into application api module with their own independent CMake projects. Changes for MLECO-3080 also included - they create CMake projects individial API's (again, platform agnostic) that dependent on the common logic. The API for KWS_API "joint" API has been removed and now the use case relies on individual KWS, and ASR API libraries. Change-Id: I1f7748dc767abb3904634a04e0991b74ac7b756d Signed-off-by: Kshitij Sisodia --- source/application/api/use_case/asr/CMakeLists.txt | 43 +++++ .../api/use_case/asr/include/AsrClassifier.hpp | 63 ++++++ .../api/use_case/asr/include/AsrResult.hpp | 63 ++++++ .../api/use_case/asr/include/OutputDecode.hpp | 40 ++++ .../api/use_case/asr/include/Wav2LetterMfcc.hpp | 109 +++++++++++ .../api/use_case/asr/include/Wav2LetterModel.hpp | 67 +++++++ .../use_case/asr/include/Wav2LetterPostprocess.hpp | 109 +++++++++++ .../use_case/asr/include/Wav2LetterPreprocess.hpp | 182 ++++++++++++++++++ .../api/use_case/asr/src/AsrClassifier.cc | 144 ++++++++++++++ .../api/use_case/asr/src/OutputDecode.cc | 47 +++++ .../api/use_case/asr/src/Wav2LetterMfcc.cc | 141 ++++++++++++++ .../api/use_case/asr/src/Wav2LetterModel.cc | 42 ++++ .../api/use_case/asr/src/Wav2LetterPostprocess.cc | 214 +++++++++++++++++++++ .../api/use_case/asr/src/Wav2LetterPreprocess.cc | 208 ++++++++++++++++++++ 14 files changed, 1472 insertions(+) create mode 100644 source/application/api/use_case/asr/CMakeLists.txt create mode 100644 source/application/api/use_case/asr/include/AsrClassifier.hpp create mode 100644 source/application/api/use_case/asr/include/AsrResult.hpp create mode 100644 source/application/api/use_case/asr/include/OutputDecode.hpp create mode 100644 source/application/api/use_case/asr/include/Wav2LetterMfcc.hpp create mode 100644 source/application/api/use_case/asr/include/Wav2LetterModel.hpp create mode 100644 source/application/api/use_case/asr/include/Wav2LetterPostprocess.hpp create mode 100644 source/application/api/use_case/asr/include/Wav2LetterPreprocess.hpp create mode 100644 source/application/api/use_case/asr/src/AsrClassifier.cc create mode 100644 source/application/api/use_case/asr/src/OutputDecode.cc create mode 100644 source/application/api/use_case/asr/src/Wav2LetterMfcc.cc create mode 100644 source/application/api/use_case/asr/src/Wav2LetterModel.cc create mode 100644 source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc create mode 100644 source/application/api/use_case/asr/src/Wav2LetterPreprocess.cc (limited to 'source/application/api/use_case/asr') diff --git a/source/application/api/use_case/asr/CMakeLists.txt b/source/application/api/use_case/asr/CMakeLists.txt new file mode 100644 index 0000000..77e3d6a --- /dev/null +++ b/source/application/api/use_case/asr/CMakeLists.txt @@ -0,0 +1,43 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2022 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#---------------------------------------------------------------------------- +######################################################### +# AUTOMATIC SPEECH RECOGNITION API library # +######################################################### +cmake_minimum_required(VERSION 3.15.6) + +set(ASR_API_TARGET asr_api) +project(${ASR_API_TARGET} + DESCRIPTION "Automatic speech recognition use case API library" + LANGUAGES C CXX) + +# Create static library +add_library(${ASR_API_TARGET} STATIC + src/Wav2LetterPreprocess.cc + src/Wav2LetterPostprocess.cc + src/Wav2LetterMfcc.cc + src/AsrClassifier.cc + src/OutputDecode.cc + src/Wav2LetterModel.cc) + +target_include_directories(${ASR_API_TARGET} PUBLIC include) + +target_link_libraries(${ASR_API_TARGET} PUBLIC common_api) + +message(STATUS "*******************************************************") +message(STATUS "Library : " ${ASR_API_TARGET}) +message(STATUS "CMAKE_SYSTEM_PROCESSOR : " ${CMAKE_SYSTEM_PROCESSOR}) +message(STATUS "*******************************************************") diff --git a/source/application/api/use_case/asr/include/AsrClassifier.hpp b/source/application/api/use_case/asr/include/AsrClassifier.hpp new file mode 100644 index 0000000..a07a721 --- /dev/null +++ b/source/application/api/use_case/asr/include/AsrClassifier.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_CLASSIFIER_HPP +#define ASR_CLASSIFIER_HPP + +#include "Classifier.hpp" + +namespace arm { +namespace app { + + class AsrClassifier : public Classifier { + public: + /** + * @brief Gets the top N classification results from the + * output vector. + * @param[in] outputTensor Inference output tensor from an NN model. + * @param[out] vecResults A vector of classification results + * populated by this function. + * @param[in] labels Labels vector to match classified classes + * @param[in] topNCount Number of top classifications to pick. + * @param[in] use_softmax Whether softmax scaling should be applied to model output. + * @return true if successful, false otherwise. + **/ + bool GetClassificationResults(TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector& labels, + uint32_t topNCount, bool use_softmax = false) override; + + private: + /** + * @brief Utility function that gets the top 1 classification results from the + * output tensor (vector of vector). + * @param[in] tensor Inference output tensor from an NN model. + * @param[out] vecResults Vector of classification results populated by this function. + * @param[in] labels Labels vector to match classified classes. + * @param[in] scale Quantization scale. + * @param[in] zeroPoint Quantization zero point. + * @return true if successful, false otherwise. + **/ + template + bool GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector& labels, double scale, double zeroPoint); + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_CLASSIFIER_HPP */ \ No newline at end of file diff --git a/source/application/api/use_case/asr/include/AsrResult.hpp b/source/application/api/use_case/asr/include/AsrResult.hpp new file mode 100644 index 0000000..ed826d0 --- /dev/null +++ b/source/application/api/use_case/asr/include/AsrResult.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_RESULT_HPP +#define ASR_RESULT_HPP + +#include "ClassificationResult.hpp" + +#include + +namespace arm { +namespace app { +namespace asr { + + using ResultVec = std::vector; + + /* Structure for holding ASR result. */ + class AsrResult { + + public: + ResultVec m_resultVec; /* Container for "thresholded" classification results. */ + float m_timeStamp; /* Audio timestamp for this result. */ + uint32_t m_inferenceNumber; /* Corresponding inference number. */ + float m_threshold; /* Threshold value for `m_resultVec.` */ + + AsrResult() = delete; + AsrResult(ResultVec& resultVec, + const float timestamp, + const uint32_t inferenceIdx, + const float scoreThreshold) { + + this->m_threshold = scoreThreshold; + this->m_timeStamp = timestamp; + this->m_inferenceNumber = inferenceIdx; + + this->m_resultVec = ResultVec(); + for (auto& i : resultVec) { + if (i.m_normalisedVal >= this->m_threshold) { + this->m_resultVec.emplace_back(i); + } + } + } + ~AsrResult() = default; + }; + +} /* namespace asr */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_RESULT_HPP */ \ No newline at end of file diff --git a/source/application/api/use_case/asr/include/OutputDecode.hpp b/source/application/api/use_case/asr/include/OutputDecode.hpp new file mode 100644 index 0000000..9d39057 --- /dev/null +++ b/source/application/api/use_case/asr/include/OutputDecode.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_OUTPUT_DECODE_HPP +#define ASR_OUTPUT_DECODE_HPP + +#include "AsrClassifier.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + /** + * @brief Gets the top N classification results from the + * output vector. + * @param[in] vecResults Label output from classifier. + * @return true if successful, false otherwise. + **/ + std::string DecodeOutput(const std::vector& vecResults); + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_OUTPUT_DECODE_HPP */ \ No newline at end of file diff --git a/source/application/api/use_case/asr/include/Wav2LetterMfcc.hpp b/source/application/api/use_case/asr/include/Wav2LetterMfcc.hpp new file mode 100644 index 0000000..b5a21d3 --- /dev/null +++ b/source/application/api/use_case/asr/include/Wav2LetterMfcc.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_MFCC_HPP +#define ASR_WAV2LETTER_MFCC_HPP + +#include "Mfcc.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide Wav2Letter specific MFCC calculation requirements. */ + class Wav2LetterMFCC : public MFCC { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 128; + static constexpr uint32_t ms_defaultMelLoFreq = 0; + static constexpr uint32_t ms_defaultMelHiFreq = 8000; + static constexpr bool ms_defaultUseHtkMethod = false; + + explicit Wav2LetterMFCC(const size_t numFeats, const size_t frameLen) + : MFCC(MfccParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + numFeats, frameLen, ms_defaultUseHtkMethod)) + {} + + Wav2LetterMFCC() = delete; + ~Wav2LetterMFCC() = default; + + protected: + + /** + * @brief Overrides base class implementation of this function. + * @param[in] fftVec Vector populated with FFT magnitudes + * @param[in] melFilterBank 2D Vector with filter bank weights + * @param[in] filterBankFilterFirst Vector containing the first indices of filter bank + * to be used for each bin. + * @param[in] filterBankFilterLast Vector containing the last indices of filter bank + * to be used for each bin. + * @param[out] melEnergies Pre-allocated vector of MEL energies to be + * populated. + * @return true if successful, false otherwise + */ + bool ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) override; + + /** + * @brief Override for the base class implementation convert mel + * energies to logarithmic scale. The difference from + * default behaviour is that the power is converted to dB + * and subsequently clamped. + * @param[in,out] melEnergies 1D vector of Mel energies + **/ + void ConvertToLogarithmicScale(std::vector& melEnergies) override; + + /** + * @brief Create a matrix used to calculate Discrete Cosine + * Transform. Override for the base class' default + * implementation as the first and last elements + * use a different normaliser. + * @param[in] inputLength input length of the buffer on which + * DCT will be performed + * @param[in] coefficientCount Total coefficients per input length. + * @return 1D vector with inputLength x coefficientCount elements + * populated with DCT coefficients. + */ + std::vector CreateDCTMatrix(int32_t inputLength, + int32_t coefficientCount) override; + + /** + * @brief Given the low and high Mel values, get the normaliser + * for weights to be applied when populating the filter + * bank. Override for the base class implementation. + * @param[in] leftMel Low Mel frequency value. + * @param[in] rightMel High Mel frequency value. + * @param[in] useHTKMethod bool to signal if HTK method is to be + * used for calculation. + * @return Value to use for normalising. + */ + float GetMelFilterBankNormaliser(const float& leftMel, + const float& rightMel, + bool useHTKMethod) override; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_MFCC_HPP */ \ No newline at end of file diff --git a/source/application/api/use_case/asr/include/Wav2LetterModel.hpp b/source/application/api/use_case/asr/include/Wav2LetterModel.hpp new file mode 100644 index 0000000..a02eed1 --- /dev/null +++ b/source/application/api/use_case/asr/include/Wav2LetterModel.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_MODEL_HPP +#define ASR_WAV2LETTER_MODEL_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { +namespace asr { + extern const int g_FrameLength; + extern const int g_FrameStride; + extern const float g_ScoreThreshold; + extern const int g_ctxLen; +} /* namespace asr */ +} /* namespace app */ +} /* namespace arm */ + +namespace arm { +namespace app { + + class Wav2LetterModel : public Model { + + public: + /* Indices for the expected model - based on input and output tensor shapes */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_outputRowsIdx = 2; + static constexpr uint32_t ms_outputColsIdx = 3; + + /* Model specific constants. */ + static constexpr uint32_t ms_blankTokenIdx = 28; + static constexpr uint32_t ms_numMfccFeatures = 13; + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. */ + bool EnlistOperations() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int ms_maxOpCnt = 5; + + /* A mutable op resolver instance. */ + tflite::MicroMutableOpResolver m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_MODEL_HPP */ diff --git a/source/application/api/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/application/api/use_case/asr/include/Wav2LetterPostprocess.hpp new file mode 100644 index 0000000..02738bc --- /dev/null +++ b/source/application/api/use_case/asr/include/Wav2LetterPostprocess.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_POSTPROCESS_HPP +#define ASR_WAV2LETTER_POSTPROCESS_HPP + +#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */ +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "AsrClassifier.hpp" +#include "AsrResult.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + /** + * @brief Helper class to manage tensor post-processing for "wav2letter" + * output. + */ + class AsrPostProcess : public BasePostProcess { + public: + bool m_lastIteration = false; /* Flag to set if processing the last set of data for a clip. */ + + /** + * @brief Constructor + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in/out] result Vector of classification results to store decoded outputs. + * @param[in] outputContextLen Left/right context length for output tensor. + * @param[in] blankTokenIdx Index in the labels that the "Blank token" takes. + * @param[in] reductionAxis The axis that the logits of each time step is on. + **/ + AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector& labels, asr::ResultVec& result, + uint32_t outputContextLen, + uint32_t blankTokenIdx, uint32_t reductionAxis); + + /** + * @brief Should perform post-processing of the result of inference then + * populate ASR result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + + /** @brief Gets the output inner length for post-processing. */ + static uint32_t GetOutputInnerLen(const TfLiteTensor*, uint32_t outputCtxLen); + + /** @brief Gets the output context length (left/right) for post-processing. */ + static uint32_t GetOutputContextLen(const Model& model, uint32_t inputCtxLen); + + /** @brief Gets the number of feature vectors to be computed. */ + static uint32_t GetNumFeatureVectors(const Model& model); + + private: + AsrClassifier& m_classifier; /* ASR Classifier object. */ + TfLiteTensor* m_outputTensor; /* Model output tensor. */ + const std::vector& m_labels; /* ASR Labels. */ + asr::ResultVec & m_results; /* Results vector for a single inference. */ + uint32_t m_outputContextLen; /* lengths of left/right contexts for output. */ + uint32_t m_outputInnerLen; /* Length of output inner context. */ + uint32_t m_totalLen; /* Total length of the required axis. */ + uint32_t m_countIterations; /* Current number of iterations. */ + uint32_t m_blankTokenIdx; /* Index of the labels blank token. */ + uint32_t m_reductionAxisIdx; /* Axis containing output logits for a single step. */ + + /** + * @brief Checks if the tensor and axis index are valid + * inputs to the object - based on how it has been initialised. + * @return true if valid, false otherwise. + */ + bool IsInputValid(TfLiteTensor* tensor, + uint32_t axisIdx) const; + + /** + * @brief Gets the tensor data element size in bytes based + * on the tensor type. + * @return Size in bytes, 0 if not supported. + */ + static uint32_t GetTensorElementSize(TfLiteTensor* tensor); + + /** + * @brief Erases sections from the data assuming row-wise + * arrangement along the context axis. + * @return true if successful, false otherwise. + */ + bool EraseSectionsRowWise(uint8_t* ptrData, + uint32_t strideSzBytes, + bool lastIteration); + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_POSTPROCESS_HPP */ \ No newline at end of file diff --git a/source/application/api/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/application/api/use_case/asr/include/Wav2LetterPreprocess.hpp new file mode 100644 index 0000000..9943946 --- /dev/null +++ b/source/application/api/use_case/asr/include/Wav2LetterPreprocess.hpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_PREPROCESS_HPP +#define ASR_WAV2LETTER_PREPROCESS_HPP + +#include "TensorFlowLiteMicro.hpp" +#include "Wav2LetterMfcc.hpp" +#include "AudioUtils.hpp" +#include "DataStructures.hpp" +#include "BaseProcessing.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + /* Class to facilitate pre-processing calculation for Wav2Letter model + * for ASR. */ + using AudioWindow = audio::SlidingWindow; + + class AsrPreProcess : public BasePreProcess { + public: + /** + * @brief Constructor. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated + * for an inference. + * @param[in] mfccWindowLen Number of audio elements to calculate MFCC features per window. + * @param[in] mfccWindowStride Stride (in number of elements) for moving the MFCC window. + */ + AsrPreProcess(TfLiteTensor* inputTensor, + uint32_t numMfccFeatures, + uint32_t numFeatureFrames, + uint32_t mfccWindowLen, + uint32_t mfccWindowStride); + + /** + * @brief Calculates the features required from audio data. This + * includes MFCC, first and second order deltas, + * normalisation and finally, quantisation. The tensor is + * populated with features from a given window placed along + * in a single row. + * @param[in] audioData Pointer to the first element of audio data. + * @param[in] audioDataLen Number of elements in the audio data. + * @return true if successful, false in case of error. + */ + bool DoPreProcess(const void* audioData, size_t audioDataLen) override; + + protected: + /** + * @brief Computes the first and second order deltas for the + * MFCC buffers - they are assumed to be populated. + * + * @param[in] mfcc MFCC buffers. + * @param[out] delta1 Result of the first diff computation. + * @param[out] delta2 Result of the second diff computation. + * @return true if successful, false otherwise. + */ + static bool ComputeDeltas(Array2d& mfcc, + Array2d& delta1, + Array2d& delta2); + + /** + * @brief Given a 2D vector of floats, rescale it to have mean of 0 and + * standard deviation of 1. + * @param[in,out] vec Vector of vector of floats. + */ + static void StandardizeVecF32(Array2d& vec); + + /** + * @brief Standardizes all the MFCC and delta buffers to have mean 0 and std. dev 1. + */ + void Standarize(); + + /** + * @brief Given the quantisation and data type limits, computes + * the quantised values of a floating point input data. + * @param[in] elem Element to be quantised. + * @param[in] quantScale Scale. + * @param[in] quantOffset Offset. + * @param[in] minVal Numerical limit - minimum. + * @param[in] maxVal Numerical limit - maximum. + * @return Floating point quantised value. + */ + static float GetQuantElem( + float elem, + float quantScale, + int quantOffset, + float minVal, + float maxVal); + + /** + * @brief Quantises the MFCC and delta buffers, and places them + * in the output buffer. While doing so, it transposes + * the data. Reason: Buffers in this class are arranged + * for "time" axis to be row major. Primary reason for + * this being the convolution speed up (as we can use + * contiguous memory). The output, however, requires the + * time axis to be in column major arrangement. + * @param[in] outputBuf Pointer to the output buffer. + * @param[in] outputBufSz Output buffer's size. + * @param[in] quantScale Quantisation scale. + * @param[in] quantOffset Quantisation offset. + */ + template + bool Quantise( + T* outputBuf, + const uint32_t outputBufSz, + const float quantScale, + const int quantOffset) + { + /* Check the output size will fit everything. */ + if (outputBufSz < (this->m_mfccBuf.size(0) * 3 * sizeof(T))) { + printf_err("Tensor size too small for features\n"); + return false; + } + + /* Populate. */ + T* outputBufMfcc = outputBuf; + T* outputBufD1 = outputBuf + this->m_numMfccFeats; + T* outputBufD2 = outputBufD1 + this->m_numMfccFeats; + const uint32_t ptrIncr = this->m_numMfccFeats * 2; /* (3 vectors - 1 vector) */ + + const float minVal = std::numeric_limits::min(); + const float maxVal = std::numeric_limits::max(); + + /* Need to transpose while copying and concatenating the tensor. */ + for (uint32_t j = 0; j < this->m_numFeatureFrames; ++j) { + for (uint32_t i = 0; i < this->m_numMfccFeats; ++i) { + *outputBufMfcc++ = static_cast(AsrPreProcess::GetQuantElem( + this->m_mfccBuf(i, j), quantScale, + quantOffset, minVal, maxVal)); + *outputBufD1++ = static_cast(AsrPreProcess::GetQuantElem( + this->m_delta1Buf(i, j), quantScale, + quantOffset, minVal, maxVal)); + *outputBufD2++ = static_cast(AsrPreProcess::GetQuantElem( + this->m_delta2Buf(i, j), quantScale, + quantOffset, minVal, maxVal)); + } + outputBufMfcc += ptrIncr; + outputBufD1 += ptrIncr; + outputBufD2 += ptrIncr; + } + + return true; + } + + private: + audio::Wav2LetterMFCC m_mfcc; /* MFCC instance. */ + TfLiteTensor* m_inputTensor; /* Model input tensor. */ + + /* Actual buffers to be populated. */ + Array2d m_mfccBuf; /* Contiguous buffer 1D: MFCC */ + Array2d m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ + Array2d m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ + + uint32_t m_mfccWindowLen; /* Window length for MFCC. */ + uint32_t m_mfccWindowStride; /* Window stride len for MFCC. */ + uint32_t m_numMfccFeats; /* Number of MFCC features per window. */ + uint32_t m_numFeatureFrames; /* How many sets of m_numMfccFeats. */ + AudioWindow m_mfccSlidingWindow; /* Sliding window to calculate MFCCs. */ + + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_PREPROCESS_HPP */ \ No newline at end of file diff --git a/source/application/api/use_case/asr/src/AsrClassifier.cc b/source/application/api/use_case/asr/src/AsrClassifier.cc new file mode 100644 index 0000000..4ba8c7b --- /dev/null +++ b/source/application/api/use_case/asr/src/AsrClassifier.cc @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AsrClassifier.hpp" + +#include "log_macros.h" +#include "TensorFlowLiteMicro.hpp" +#include "Wav2LetterModel.hpp" + +namespace arm { +namespace app { + + template + bool AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint) + { + const uint32_t nElems = tensor->dims->data[Wav2LetterModel::ms_outputRowsIdx]; + const uint32_t nLetters = tensor->dims->data[Wav2LetterModel::ms_outputColsIdx]; + + if (nLetters != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } + + /* NOTE: tensor's size verification against labels should be + * checked by the calling/public function. */ + if (nLetters < 1) { + return false; + } + + /* Final results' container. */ + vecResults = std::vector(nElems); + + T* tensorData = tflite::GetTensorData(tensor); + + /* Get the top 1 results. */ + for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { + std::pair top_1 = std::make_pair(tensorData[row + 0], 0); + + for (uint32_t j = 1; j < nLetters; ++j) { + if (top_1.first < tensorData[row + j]) { + top_1.first = tensorData[row + j]; + top_1.second = j; + } + } + + double score = static_cast (top_1.first); + vecResults[i].m_normalisedVal = scale * (score - zeroPoint); + vecResults[i].m_label = labels[top_1.second]; + vecResults[i].m_labelIdx = top_1.second; + } + + return true; + } + template bool AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, + double scale, double zeroPoint); + template bool AsrClassifier::GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, + double scale, double zeroPoint); + + bool AsrClassifier::GetClassificationResults( + TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector & labels, uint32_t topNCount, bool use_softmax) + { + UNUSED(use_softmax); + vecResults.clear(); + + constexpr int minTensorDims = static_cast( + (Wav2LetterModel::ms_outputRowsIdx > Wav2LetterModel::ms_outputColsIdx)? + Wav2LetterModel::ms_outputRowsIdx : Wav2LetterModel::ms_outputColsIdx); + + constexpr uint32_t outColsIdx = Wav2LetterModel::ms_outputColsIdx; + + /* Sanity checks. */ + if (outputTensor == nullptr) { + printf_err("Output vector is null pointer.\n"); + return false; + } else if (outputTensor->dims->size < minTensorDims) { + printf_err("Output tensor expected to be %dD\n", minTensorDims); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) < topNCount) { + printf_err("Output vectors are smaller than %" PRIu32 "\n", topNCount); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } + + if (topNCount != 1) { + warn("TopNCount value ignored in this implementation\n"); + } + + /* To return the floating point values, we need quantization parameters. */ + QuantParams quantParams = GetTensorQuantParams(outputTensor); + + bool resultState; + + switch (outputTensor->type) { + case kTfLiteUInt8: + resultState = this->GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + case kTfLiteInt8: + resultState = this->GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + default: + printf_err("Tensor type %s not supported by classifier\n", + TfLiteTypeGetName(outputTensor->type)); + return false; + } + + if (!resultState) { + printf_err("Failed to get sorted set\n"); + return false; + } + + return true; + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/application/api/use_case/asr/src/OutputDecode.cc b/source/application/api/use_case/asr/src/OutputDecode.cc new file mode 100644 index 0000000..41fbe07 --- /dev/null +++ b/source/application/api/use_case/asr/src/OutputDecode.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "OutputDecode.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + std::string DecodeOutput(const std::vector& vecResults) + { + std::string CleanOutputBuffer; + + for (size_t i = 0; i < vecResults.size(); ++i) /* For all elements in vector. */ + { + while (i+1 < vecResults.size() && + vecResults[i].m_label == vecResults[i+1].m_label) /* While the current element is equal to the next, ignore it and move on. */ + { + ++i; + } + if (vecResults[i].m_label != "$") /* $ is a character used to represent unknown and double characters so should not be in output. */ + { + CleanOutputBuffer += vecResults[i].m_label; /* If the element is different to the next, it will be appended to CleanOutputBuffer. */ + } + } + + return CleanOutputBuffer; /* Return string type containing clean output. */ + } + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/application/api/use_case/asr/src/Wav2LetterMfcc.cc b/source/application/api/use_case/asr/src/Wav2LetterMfcc.cc new file mode 100644 index 0000000..bb29b0f --- /dev/null +++ b/source/application/api/use_case/asr/src/Wav2LetterMfcc.cc @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterMfcc.hpp" + +#include "PlatformMath.hpp" +#include "log_macros.h" + +#include + +namespace arm { +namespace app { +namespace audio { + + bool Wav2LetterMFCC::ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) + { + const size_t numBanks = melEnergies.size(); + + if (numBanks != filterBankFilterFirst.size() || + numBanks != filterBankFilterLast.size()) { + printf_err("Unexpected filter bank lengths\n"); + return false; + } + + for (size_t bin = 0; bin < numBanks; ++bin) { + auto filterBankIter = melFilterBank[bin].begin(); + auto end = melFilterBank[bin].end(); + /* Avoid log of zero at later stages, same value used in librosa. + * The number was used during our default wav2letter model training. */ + float melEnergy = 1e-10; + const uint32_t firstIndex = filterBankFilterFirst[bin]; + const uint32_t lastIndex = std::min(filterBankFilterLast[bin], fftVec.size() - 1); + + for (uint32_t i = firstIndex; i <= lastIndex && filterBankIter != end; ++i) { + melEnergy += (*filterBankIter++ * fftVec[i]); + } + + melEnergies[bin] = melEnergy; + } + + return true; + } + + void Wav2LetterMFCC::ConvertToLogarithmicScale( + std::vector& melEnergies) + { + float maxMelEnergy = -FLT_MAX; + + /* Container for natural logarithms of mel energies. */ + std::vector vecLogEnergies(melEnergies.size(), 0.f); + + /* Because we are taking natural logs, we need to multiply by log10(e). + * Also, for wav2letter model, we scale our log10 values by 10. */ + constexpr float multiplier = 10.0 * /* Default scalar. */ + 0.4342944819032518; /* log10f(std::exp(1.0)) */ + + /* Take log of the whole vector. */ + math::MathUtils::VecLogarithmF32(melEnergies, vecLogEnergies); + + /* Scale the log values and get the max. */ + for (auto iterM = melEnergies.begin(), iterL = vecLogEnergies.begin(); + iterM != melEnergies.end() && iterL != vecLogEnergies.end(); ++iterM, ++iterL) { + + *iterM = *iterL * multiplier; + + /* Save the max mel energy. */ + if (*iterM > maxMelEnergy) { + maxMelEnergy = *iterM; + } + } + + /* Clamp the mel energies. */ + constexpr float maxDb = 80.0; + const float clampLevelLowdB = maxMelEnergy - maxDb; + for (float& melEnergy : melEnergies) { + melEnergy = std::max(melEnergy, clampLevelLowdB); + } + } + + std::vector Wav2LetterMFCC::CreateDCTMatrix( + const int32_t inputLength, + const int32_t coefficientCount) + { + std::vector dctMatix(inputLength * coefficientCount); + + /* Orthonormal normalization. */ + const float normalizerK0 = 2 * math::MathUtils::SqrtF32(1.0f / + static_cast(4*inputLength)); + const float normalizer = 2 * math::MathUtils::SqrtF32(1.0f / + static_cast(2*inputLength)); + + const float angleIncr = M_PI / inputLength; + float angle = angleIncr; /* We start using it at k = 1 loop. */ + + /* First row of DCT will use normalizer K0. */ + for (int32_t n = 0; n < inputLength; ++n) { + dctMatix[n] = normalizerK0 /* cos(0) = 1 */; + } + + /* Second row (index = 1) onwards, we use standard normalizer. */ + for (int32_t k = 1, m = inputLength; k < coefficientCount; ++k, m += inputLength) { + for (int32_t n = 0; n < inputLength; ++n) { + dctMatix[m+n] = normalizer * + math::MathUtils::CosineF32((n + 0.5f) * angle); + } + angle += angleIncr; + } + return dctMatix; + } + + float Wav2LetterMFCC::GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod) + { + /* Slaney normalization for mel weights. */ + return (2.0f / (MFCC::InverseMelScale(rightMel, useHTKMethod) - + MFCC::InverseMelScale(leftMel, useHTKMethod))); + } + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/application/api/use_case/asr/src/Wav2LetterModel.cc b/source/application/api/use_case/asr/src/Wav2LetterModel.cc new file mode 100644 index 0000000..7b1e521 --- /dev/null +++ b/source/application/api/use_case/asr/src/Wav2LetterModel.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterModel.hpp" + +#include "log_macros.h" + + +const tflite::MicroOpResolver& arm::app::Wav2LetterModel::GetOpResolver() +{ + return this->m_opResolver; +} + +bool arm::app::Wav2LetterModel::EnlistOperations() +{ + this->m_opResolver.AddConv2D(); + this->m_opResolver.AddReshape(); + this->m_opResolver.AddLeakyRelu(); + this->m_opResolver.AddSoftmax(); + + if (kTfLiteOk == this->m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } + return true; +} diff --git a/source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc b/source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc new file mode 100644 index 0000000..00e689b --- /dev/null +++ b/source/application/api/use_case/asr/src/Wav2LetterPostprocess.cc @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterPostprocess.hpp" + +#include "Wav2LetterModel.hpp" +#include "log_macros.h" + +#include + +namespace arm { +namespace app { + + AsrPostProcess::AsrPostProcess(TfLiteTensor* outputTensor, AsrClassifier& classifier, + const std::vector& labels, std::vector& results, + const uint32_t outputContextLen, + const uint32_t blankTokenIdx, const uint32_t reductionAxisIdx + ): + m_classifier(classifier), + m_outputTensor(outputTensor), + m_labels{labels}, + m_results(results), + m_outputContextLen(outputContextLen), + m_countIterations(0), + m_blankTokenIdx(blankTokenIdx), + m_reductionAxisIdx(reductionAxisIdx) + { + this->m_outputInnerLen = AsrPostProcess::GetOutputInnerLen(this->m_outputTensor, this->m_outputContextLen); + this->m_totalLen = (2 * this->m_outputContextLen + this->m_outputInnerLen); + } + + bool AsrPostProcess::DoPostProcess() + { + /* Basic checks. */ + if (!this->IsInputValid(this->m_outputTensor, this->m_reductionAxisIdx)) { + return false; + } + + /* Irrespective of tensor type, we use unsigned "byte" */ + auto* ptrData = tflite::GetTensorData(this->m_outputTensor); + const uint32_t elemSz = AsrPostProcess::GetTensorElementSize(this->m_outputTensor); + + /* Other sanity checks. */ + if (0 == elemSz) { + printf_err("Tensor type not supported for post processing\n"); + return false; + } else if (elemSz * this->m_totalLen > this->m_outputTensor->bytes) { + printf_err("Insufficient number of tensor bytes\n"); + return false; + } + + /* Which axis do we need to process? */ + switch (this->m_reductionAxisIdx) { + case Wav2LetterModel::ms_outputRowsIdx: + this->EraseSectionsRowWise( + ptrData, elemSz * this->m_outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx], + this->m_lastIteration); + break; + default: + printf_err("Unsupported axis index: %" PRIu32 "\n", this->m_reductionAxisIdx); + return false; + } + this->m_classifier.GetClassificationResults(this->m_outputTensor, + this->m_results, this->m_labels, 1); + + return true; + } + + bool AsrPostProcess::IsInputValid(TfLiteTensor* tensor, const uint32_t axisIdx) const + { + if (nullptr == tensor) { + return false; + } + + if (static_cast(axisIdx) >= tensor->dims->size) { + printf_err("Invalid axis index: %" PRIu32 "; Max: %d\n", + axisIdx, tensor->dims->size); + return false; + } + + if (static_cast(this->m_totalLen) != + tensor->dims->data[axisIdx]) { + printf_err("Unexpected tensor dimension for axis %" PRIu32", got %d.\n", + axisIdx, tensor->dims->data[axisIdx]); + return false; + } + + return true; + } + + uint32_t AsrPostProcess::GetTensorElementSize(TfLiteTensor* tensor) + { + switch(tensor->type) { + case kTfLiteUInt8: + case kTfLiteInt8: + return 1; + case kTfLiteInt16: + return 2; + case kTfLiteInt32: + case kTfLiteFloat32: + return 4; + default: + printf_err("Unsupported tensor type %s\n", + TfLiteTypeGetName(tensor->type)); + } + + return 0; + } + + bool AsrPostProcess::EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) + { + /* In this case, the "zero-ing" is quite simple as the region + * to be zeroed sits in contiguous memory (row-major). */ + const uint32_t eraseLen = strideSzBytes * this->m_outputContextLen; + + /* Erase left context? */ + if (this->m_countIterations > 0) { + /* Set output of each classification window to the blank token. */ + std::memset(ptrData, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { + ptrData[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; + } + } + + /* Erase right context? */ + if (false == lastIteration) { + uint8_t* rightCtxPtr = ptrData + (strideSzBytes * (this->m_outputContextLen + this->m_outputInnerLen)); + /* Set output of each classification window to the blank token. */ + std::memset(rightCtxPtr, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->m_outputContextLen; windowIdx++) { + rightCtxPtr[windowIdx*strideSzBytes + this->m_blankTokenIdx] = 1; + } + } + + if (lastIteration) { + this->m_countIterations = 0; + } else { + ++this->m_countIterations; + } + + return true; + } + + uint32_t AsrPostProcess::GetNumFeatureVectors(const Model& model) + { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = std::max(inputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx], 0); + if (inputRows == 0) { + printf_err("Error getting number of input rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_inputRowsIdx); + } + return inputRows; + } + + uint32_t AsrPostProcess::GetOutputInnerLen(const TfLiteTensor* outputTensor, const uint32_t outputCtxLen) + { + const uint32_t outputRows = std::max(outputTensor->dims->data[Wav2LetterModel::ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + } + + /* Watching for underflow. */ + int innerLen = (outputRows - (2 * outputCtxLen)); + + return std::max(innerLen, 0); + } + + uint32_t AsrPostProcess::GetOutputContextLen(const Model& model, const uint32_t inputCtxLen) + { + const uint32_t inputRows = AsrPostProcess::GetNumFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above + * context and inner lengths. */ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %" PRIu32 "\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + if (outputRows == 0) { + printf_err("Error getting number of output rows for axis: %" PRIu32 "\n", + Wav2LetterModel::ms_outputRowsIdx); + return 0; + } + + const float inOutRowRatio = static_cast(inputRows) / + static_cast(outputRows); + + return std::round(static_cast(inputCtxLen) / inOutRowRatio); + } + +} /* namespace app */ +} /* namespace arm */ diff --git a/source/application/api/use_case/asr/src/Wav2LetterPreprocess.cc b/source/application/api/use_case/asr/src/Wav2LetterPreprocess.cc new file mode 100644 index 0000000..92b0631 --- /dev/null +++ b/source/application/api/use_case/asr/src/Wav2LetterPreprocess.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterPreprocess.hpp" + +#include "PlatformMath.hpp" +#include "TensorFlowLiteMicro.hpp" + +#include +#include + +namespace arm { +namespace app { + + AsrPreProcess::AsrPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, + const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, + const uint32_t mfccWindowStride + ): + m_mfcc(numMfccFeatures, mfccWindowLen), + m_inputTensor(inputTensor), + m_mfccBuf(numMfccFeatures, numFeatureFrames), + m_delta1Buf(numMfccFeatures, numFeatureFrames), + m_delta2Buf(numMfccFeatures, numFeatureFrames), + m_mfccWindowLen(mfccWindowLen), + m_mfccWindowStride(mfccWindowStride), + m_numMfccFeats(numMfccFeatures), + m_numFeatureFrames(numFeatureFrames) + { + if (numMfccFeatures > 0 && mfccWindowLen > 0) { + this->m_mfcc.Init(); + } + } + + bool AsrPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen) + { + this->m_mfccSlidingWindow = audio::SlidingWindow( + static_cast(audioData), audioDataLen, + this->m_mfccWindowLen, this->m_mfccWindowStride); + + uint32_t mfccBufIdx = 0; + + std::fill(m_mfccBuf.begin(), m_mfccBuf.end(), 0.f); + std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f); + std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f); + + /* While we can slide over the audio. */ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); + auto mfccAudioData = std::vector( + mfccWindow, + mfccWindow + this->m_mfccWindowLen); + auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData); + for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) { + this->m_mfccBuf(i, mfccBufIdx) = mfcc[i]; + } + ++mfccBufIdx; + } + + /* Pad MFCC if needed by adding MFCC for zeros. */ + if (mfccBufIdx != this->m_numFeatureFrames) { + std::vector zerosWindow = std::vector(this->m_mfccWindowLen, 0); + std::vector mfccZeros = this->m_mfcc.MfccCompute(zerosWindow); + + while (mfccBufIdx != this->m_numFeatureFrames) { + memcpy(&this->m_mfccBuf(0, mfccBufIdx), + mfccZeros.data(), sizeof(float) * m_numMfccFeats); + ++mfccBufIdx; + } + } + + /* Compute first and second order deltas from MFCCs. */ + AsrPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); + + /* Standardize calculated features. */ + this->Standarize(); + + /* Quantise. */ + QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor); + + if (0 == quantParams.scale) { + printf_err("Quantisation scale can't be 0\n"); + return false; + } + + switch(this->m_inputTensor->type) { + case kTfLiteUInt8: + return this->Quantise( + tflite::GetTensorData(this->m_inputTensor), this->m_inputTensor->bytes, + quantParams.scale, quantParams.offset); + case kTfLiteInt8: + return this->Quantise( + tflite::GetTensorData(this->m_inputTensor), this->m_inputTensor->bytes, + quantParams.scale, quantParams.offset); + default: + printf_err("Unsupported tensor type %s\n", + TfLiteTypeGetName(this->m_inputTensor->type)); + } + + return false; + } + + bool AsrPreProcess::ComputeDeltas(Array2d& mfcc, + Array2d& delta1, + Array2d& delta2) + { + const std::vector delta1Coeffs = + {6.66666667e-02, 5.00000000e-02, 3.33333333e-02, + 1.66666667e-02, -3.46944695e-18, -1.66666667e-02, + -3.33333333e-02, -5.00000000e-02, -6.66666667e-02}; + + const std::vector delta2Coeffs = + {0.06060606, 0.01515152, -0.01731602, + -0.03679654, -0.04329004, -0.03679654, + -0.01731602, 0.01515152, 0.06060606}; + + if (delta1.size(0) == 0 || delta2.size(0) != delta1.size(0) || + mfcc.size(0) == 0 || mfcc.size(1) == 0) { + return false; + } + + /* Get the middle index; coeff vec len should always be odd. */ + const size_t coeffLen = delta1Coeffs.size(); + const size_t fMidIdx = (coeffLen - 1)/2; + const size_t numFeatures = mfcc.size(0); + const size_t numFeatVectors = mfcc.size(1); + + /* Iterate through features in MFCC vector. */ + for (size_t i = 0; i < numFeatures; ++i) { + /* For each feature, iterate through time (t) samples representing feature evolution and + * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels. + * Convolution padding = valid, result size is `time length - kernel length + 1`. + * The result is padded with 0 from both sides to match the size of initial time samples data. + * + * For the small filter, conv1D implementation as a simple loop is efficient enough. + * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32. + */ + + for (size_t j = fMidIdx; j < numFeatVectors - fMidIdx; ++j) { + float d1 = 0; + float d2 = 0; + const size_t mfccStIdx = j - fMidIdx; + + for (size_t k = 0, m = coeffLen - 1; k < coeffLen; ++k, --m) { + + d1 += mfcc(i,mfccStIdx + k) * delta1Coeffs[m]; + d2 += mfcc(i,mfccStIdx + k) * delta2Coeffs[m]; + } + + delta1(i,j) = d1; + delta2(i,j) = d2; + } + } + + return true; + } + + void AsrPreProcess::StandardizeVecF32(Array2d& vec) + { + auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); + + debug("Mean: %f, Stddev: %f\n", mean, stddev); + if (stddev == 0) { + std::fill(vec.begin(), vec.end(), 0); + } else { + const float stddevInv = 1.f/stddev; + const float normalisedMean = mean/stddev; + + auto NormalisingFunction = [=](float& value) { + value = value * stddevInv - normalisedMean; + }; + std::for_each(vec.begin(), vec.end(), NormalisingFunction); + } + } + + void AsrPreProcess::Standarize() + { + AsrPreProcess::StandardizeVecF32(this->m_mfccBuf); + AsrPreProcess::StandardizeVecF32(this->m_delta1Buf); + AsrPreProcess::StandardizeVecF32(this->m_delta2Buf); + } + + float AsrPreProcess::GetQuantElem( + const float elem, + const float quantScale, + const int quantOffset, + const float minVal, + const float maxVal) + { + float val = std::round((elem/quantScale) + quantOffset); + return std::min(std::max(val, minVal), maxVal); + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file -- cgit v1.2.1