diff options
Diffstat (limited to 'source/use_case/asr/src/Wav2LetterPreprocess.cc')
-rw-r--r-- | source/use_case/asr/src/Wav2LetterPreprocess.cc | 106 |
1 files changed, 43 insertions, 63 deletions
diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc index e5ac3ca..590d08a 100644 --- a/source/use_case/asr/src/Wav2LetterPreprocess.cc +++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,37 +24,31 @@ namespace arm { namespace app { -namespace audio { -namespace asr { - - Preprocess::Preprocess( - const uint32_t numMfccFeatures, - const uint32_t windowLen, - const uint32_t windowStride, - const uint32_t numMfccVectors): - m_mfcc(numMfccFeatures, windowLen), - m_mfccBuf(numMfccFeatures, numMfccVectors), - m_delta1Buf(numMfccFeatures, numMfccVectors), - m_delta2Buf(numMfccFeatures, numMfccVectors), - m_windowLen(windowLen), - m_windowStride(windowStride), + + ASRPreProcess::ASRPreProcess(TfLiteTensor* inputTensor, const uint32_t numMfccFeatures, + const uint32_t numFeatureFrames, const uint32_t mfccWindowLen, + const uint32_t mfccWindowStride + ): + m_mfcc(numMfccFeatures, mfccWindowLen), + m_inputTensor(inputTensor), + m_mfccBuf(numMfccFeatures, numFeatureFrames), + m_delta1Buf(numMfccFeatures, numFeatureFrames), + m_delta2Buf(numMfccFeatures, numFeatureFrames), + m_mfccWindowLen(mfccWindowLen), + m_mfccWindowStride(mfccWindowStride), m_numMfccFeats(numMfccFeatures), - m_numFeatVectors(numMfccVectors), - m_window() + m_numFeatureFrames(numFeatureFrames) { - if (numMfccFeatures > 0 && windowLen > 0) { + if (numMfccFeatures > 0 && mfccWindowLen > 0) { this->m_mfcc.Init(); } } - bool Preprocess::Invoke( - const int16_t* audioData, - const uint32_t audioDataLen, - TfLiteTensor* tensor) + bool ASRPreProcess::DoPreProcess(const void* audioData, const size_t audioDataLen) { - this->m_window = SlidingWindow<const int16_t>( - audioData, audioDataLen, - this->m_windowLen, this->m_windowStride); + this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>( + static_cast<const int16_t*>(audioData), audioDataLen, + this->m_mfccWindowLen, this->m_mfccWindowStride); uint32_t mfccBufIdx = 0; @@ -62,12 +56,12 @@ namespace asr { std::fill(m_delta1Buf.begin(), m_delta1Buf.end(), 0.f); std::fill(m_delta2Buf.begin(), m_delta2Buf.end(), 0.f); - /* While we can slide over the window. */ - while (this->m_window.HasNext()) { - const int16_t* mfccWindow = this->m_window.Next(); + /* While we can slide over the audio. */ + while (this->m_mfccSlidingWindow.HasNext()) { + const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next(); auto mfccAudioData = std::vector<int16_t>( mfccWindow, - mfccWindow + this->m_windowLen); + mfccWindow + this->m_mfccWindowLen); auto mfcc = this->m_mfcc.MfccCompute(mfccAudioData); for (size_t i = 0; i < this->m_mfccBuf.size(0); ++i) { this->m_mfccBuf(i, mfccBufIdx) = mfcc[i]; @@ -76,11 +70,11 @@ namespace asr { } /* Pad MFCC if needed by adding MFCC for zeros. */ - if (mfccBufIdx != this->m_numFeatVectors) { - std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_windowLen, 0); + if (mfccBufIdx != this->m_numFeatureFrames) { + std::vector<int16_t> zerosWindow = std::vector<int16_t>(this->m_mfccWindowLen, 0); std::vector<float> mfccZeros = this->m_mfcc.MfccCompute(zerosWindow); - while (mfccBufIdx != this->m_numFeatVectors) { + while (mfccBufIdx != this->m_numFeatureFrames) { memcpy(&this->m_mfccBuf(0, mfccBufIdx), mfccZeros.data(), sizeof(float) * m_numMfccFeats); ++mfccBufIdx; @@ -88,39 +82,37 @@ namespace asr { } /* Compute first and second order deltas from MFCCs. */ - Preprocess::ComputeDeltas(this->m_mfccBuf, - this->m_delta1Buf, - this->m_delta2Buf); + ASRPreProcess::ComputeDeltas(this->m_mfccBuf, this->m_delta1Buf, this->m_delta2Buf); - /* Normalise. */ - this->Normalise(); + /* Standardize calculated features. */ + this->Standarize(); /* Quantise. */ - QuantParams quantParams = GetTensorQuantParams(tensor); + QuantParams quantParams = GetTensorQuantParams(this->m_inputTensor); if (0 == quantParams.scale) { printf_err("Quantisation scale can't be 0\n"); return false; } - switch(tensor->type) { + switch(this->m_inputTensor->type) { case kTfLiteUInt8: return this->Quantise<uint8_t>( - tflite::GetTensorData<uint8_t>(tensor), tensor->bytes, + tflite::GetTensorData<uint8_t>(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); case kTfLiteInt8: return this->Quantise<int8_t>( - tflite::GetTensorData<int8_t>(tensor), tensor->bytes, + tflite::GetTensorData<int8_t>(this->m_inputTensor), this->m_inputTensor->bytes, quantParams.scale, quantParams.offset); default: printf_err("Unsupported tensor type %s\n", - TfLiteTypeGetName(tensor->type)); + TfLiteTypeGetName(this->m_inputTensor->type)); } return false; } - bool Preprocess::ComputeDeltas(Array2d<float>& mfcc, + bool ASRPreProcess::ComputeDeltas(Array2d<float>& mfcc, Array2d<float>& delta1, Array2d<float>& delta2) { @@ -175,20 +167,10 @@ namespace asr { return true; } - float Preprocess::GetMean(Array2d<float>& vec) - { - return math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); - } - - float Preprocess::GetStdDev(Array2d<float>& vec, const float mean) - { - return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); - } - - void Preprocess::NormaliseVec(Array2d<float>& vec) + void ASRPreProcess::StandardizeVecF32(Array2d<float>& vec) { - auto mean = Preprocess::GetMean(vec); - auto stddev = Preprocess::GetStdDev(vec, mean); + auto mean = math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + auto stddev = math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); debug("Mean: %f, Stddev: %f\n", mean, stddev); if (stddev == 0) { @@ -204,14 +186,14 @@ namespace asr { } } - void Preprocess::Normalise() + void ASRPreProcess::Standarize() { - Preprocess::NormaliseVec(this->m_mfccBuf); - Preprocess::NormaliseVec(this->m_delta1Buf); - Preprocess::NormaliseVec(this->m_delta2Buf); + ASRPreProcess::StandardizeVecF32(this->m_mfccBuf); + ASRPreProcess::StandardizeVecF32(this->m_delta1Buf); + ASRPreProcess::StandardizeVecF32(this->m_delta2Buf); } - float Preprocess::GetQuantElem( + float ASRPreProcess::GetQuantElem( const float elem, const float quantScale, const int quantOffset, @@ -222,7 +204,5 @@ namespace asr { return std::min<float>(std::max<float>(val, minVal), maxVal); } -} /* namespace asr */ -} /* namespace audio */ } /* namespace app */ } /* namespace arm */
\ No newline at end of file |