summaryrefslogtreecommitdiff
path: root/source/use_case/kws/src/KwsProcessing.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/kws/src/KwsProcessing.cc')
-rw-r--r--source/use_case/kws/src/KwsProcessing.cc53
1 files changed, 23 insertions, 30 deletions
diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc
index 14f9fce..328709d 100644
--- a/source/use_case/kws/src/KwsProcessing.cc
+++ b/source/use_case/kws/src/KwsProcessing.cc
@@ -22,22 +22,19 @@
namespace arm {
namespace app {
- KWSPreProcess::KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride):
+ KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames,
+ int mfccFrameLength, int mfccFrameStride
+ ):
+ m_inputTensor{inputTensor},
m_mfccFrameLength{mfccFrameLength},
m_mfccFrameStride{mfccFrameStride},
+ m_numMfccFrames{numMfccFrames},
m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
{
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
this->m_mfcc.Init();
- TfLiteIntArray* inputShape = model->GetInputShape(0);
- const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];
-
/* Deduce the data length required for 1 inference from the network parameters. */
- this->m_audioDataWindowSize = numMfccFrames * this->m_mfccFrameStride +
+ this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride +
(this->m_mfccFrameLength - this->m_mfccFrameStride);
/* Creating an MFCC feature sliding window for the data required for 1 inference. */
@@ -62,7 +59,7 @@ namespace app {
- this->m_numMfccVectorsInAudioStride;
/* Construct feature calculation function. */
- this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_model->GetInputTensor(0),
+ this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor,
this->m_numReusedMfccVectors);
if (!this->m_mfccFeatureCalculator) {
@@ -70,7 +67,7 @@ namespace app {
}
}
- bool KWSPreProcess::DoPreProcess(const void* data, size_t inputSize)
+ bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize)
{
UNUSED(inputSize);
if (data == nullptr) {
@@ -116,8 +113,8 @@ namespace app {
*/
template<class T>
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
- KWSPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
- std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+ KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+ std::function<std::vector<T> (std::vector<int16_t>& )> compute)
{
/* Feature cache to be captured by lambda function. */
static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
@@ -149,18 +146,18 @@ namespace app {
}
template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- KWSPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+ KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
- KWSPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+ KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
+ size_t cacheSize,
+ std::function<std::vector<float>(std::vector<int16_t>&)> compute);
std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- KWSPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+ KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
{
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
@@ -195,23 +192,19 @@ namespace app {
return mfccFeatureCalc;
}
- KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model,
+ KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results)
- :m_kwsClassifier{classifier},
+ :m_outputTensor{outputTensor},
+ m_kwsClassifier{classifier},
m_labels{labels},
m_results{results}
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ {}
- bool KWSPostProcess::DoPostProcess()
+ bool KwsPostProcess::DoPostProcess()
{
return this->m_kwsClassifier.GetClassificationResults(
- this->m_model->GetOutputTensor(0), this->m_results,
+ this->m_outputTensor, this->m_results,
this->m_labels, 1, true);
}