diff options
Diffstat (limited to 'source/use_case/kws/src/KwsProcessing.cc')
-rw-r--r-- | source/use_case/kws/src/KwsProcessing.cc | 53 |
1 files changed, 23 insertions, 30 deletions
diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc index 14f9fce..328709d 100644 --- a/source/use_case/kws/src/KwsProcessing.cc +++ b/source/use_case/kws/src/KwsProcessing.cc @@ -22,22 +22,19 @@ namespace arm { namespace app { - KWSPreProcess::KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride): + KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames, + int mfccFrameLength, int mfccFrameStride + ): + m_inputTensor{inputTensor}, m_mfccFrameLength{mfccFrameLength}, m_mfccFrameStride{mfccFrameStride}, + m_numMfccFrames{numMfccFrames}, m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)} { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; this->m_mfcc.Init(); - TfLiteIntArray* inputShape = model->GetInputShape(0); - const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; - /* Deduce the data length required for 1 inference from the network parameters. */ - this->m_audioDataWindowSize = numMfccFrames * this->m_mfccFrameStride + + this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride + (this->m_mfccFrameLength - this->m_mfccFrameStride); /* Creating an MFCC feature sliding window for the data required for 1 inference. */ @@ -62,7 +59,7 @@ namespace app { - this->m_numMfccVectorsInAudioStride; /* Construct feature calculation function. */ - this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_model->GetInputTensor(0), + this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor, this->m_numReusedMfccVectors); if (!this->m_mfccFeatureCalculator) { @@ -70,7 +67,7 @@ namespace app { } } - bool KWSPreProcess::DoPreProcess(const void* data, size_t inputSize) + bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize) { UNUSED(inputSize); if (data == nullptr) { @@ -116,8 +113,8 @@ namespace app { */ template<class T> std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> - KWSPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function<std::vector<T> (std::vector<int16_t>& )> compute) + KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function<std::vector<T> (std::vector<int16_t>& )> compute) { /* Feature cache to be captured by lambda function. */ static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); @@ -149,18 +146,18 @@ namespace app { } template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)> - KWSPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); + KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)> - KWSPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<float>(std::vector<int16_t>&)> compute); + KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function<std::vector<float>(std::vector<int16_t>&)> compute); std::function<void (std::vector<int16_t>&, int, bool, size_t)> - KWSPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; @@ -195,23 +192,19 @@ namespace app { return mfccFeatureCalc; } - KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model, + KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results) - :m_kwsClassifier{classifier}, + :m_outputTensor{outputTensor}, + m_kwsClassifier{classifier}, m_labels{labels}, m_results{results} - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + {} - bool KWSPostProcess::DoPostProcess() + bool KwsPostProcess::DoPostProcess() { return this->m_kwsClassifier.GetClassificationResults( - this->m_model->GetOutputTensor(0), this->m_results, + this->m_outputTensor, this->m_results, this->m_labels, 1, true); } |