diff options
Diffstat (limited to 'source/use_case/kws/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/kws/src/UseCaseHandler.cc | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 3d95753..8085af7 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -18,9 +18,9 @@ #include "InputFiles.hpp" #include "Classifier.hpp" -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsMfcc.hpp" #include "AudioUtils.hpp" #include "UseCaseCommonUtils.hpp" #include "KwsResult.hpp" @@ -59,7 +59,7 @@ namespace app { * @return Function to be called providing audio sample and sliding window index. */ static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize); @@ -72,8 +72,8 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast<int>( - (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? - arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? + arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); auto& model = ctx.Get<Model&>("model"); @@ -105,10 +105,10 @@ namespace app { } TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx]; - const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx]; + const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; - audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength); + audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength); mfcc.Init(); /* Deduce the data length required for 1 inference from the network parameters. */ @@ -132,7 +132,7 @@ namespace app { /* We expect to be sampling 1 second worth of data at a time. * NOTE: This is only used for time stamp calculation. */ - const float secondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; do { platform.data_psn->clear(COLOR_BLACK); @@ -208,7 +208,7 @@ namespace app { std::vector<ClassificationResult> classificationResult; auto& classifier = ctx.Get<KwsClassifier&>("classifier"); classifier.GetClassificationResults(outputTensor, classificationResult, - ctx.Get<std::vector<std::string>&>("labels"), 1); + ctx.Get<std::vector<std::string>&>("labels"), 1, true); results.emplace_back(kws::KwsResult(classificationResult, audioDataSlider.Index() * secondsPerSample * audioDataStride, @@ -240,7 +240,6 @@ namespace app { return true; } - static bool PresentInferenceResult(hal_platform& platform, const std::vector<arm::app::kws::KwsResult>& results) { @@ -259,7 +258,6 @@ namespace app { std::string topKeyword{"<none>"}; float score = 0.f; - if (!results[i].m_resultVec.empty()) { topKeyword = results[i].m_resultVec[0].m_label; score = results[i].m_resultVec[0].m_normalisedVal; @@ -366,7 +364,7 @@ namespace app { static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; |