summaryrefslogtreecommitdiff
path: root/source/use_case/kws_asr/src/UseCaseHandler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/kws_asr/src/UseCaseHandler.cc')
-rw-r--r--source/use_case/kws_asr/src/UseCaseHandler.cc20
1 files changed, 10 insertions, 10 deletions
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 1d88ba1..c67be22 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -20,8 +20,8 @@
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
-#include "DsCnnModel.hpp"
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsModel.hpp"
+#include "MicroNetKwsMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
@@ -77,12 +77,12 @@ namespace app {
*
* @param[in] mfcc MFCC feature calculator.
* @param[in,out] inputTensor Input tensor pointer to store calculated features.
- * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors).
+ * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
*
* @return function function to be called providing audio sample and sliding window index.
**/
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
+ GetFeatureCalculator(audio::MicroNetMFCC& mfcc,
TfLiteTensor* inputTensor,
size_t cacheSize);
@@ -98,8 +98,8 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartY = 40;
constexpr int minTensorDims = static_cast<int>(
- (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
- arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);
+ (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
+ arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
KWSOutput output;
@@ -128,7 +128,7 @@ namespace app {
const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
- audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
+ audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
kwsMfcc.Init();
/* Deduce the data length required for 1 KWS inference from the network parameters. */
@@ -152,7 +152,7 @@ namespace app {
/* We expect to be sampling 1 second worth of data at a time
* NOTE: This is only used for time stamp calculation. */
- const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq;
+ const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
auto currentIndex = ctx.Get<uint32_t>("clipIndex");
@@ -230,7 +230,7 @@ namespace app {
kwsClassifier.GetClassificationResults(
kwsOutputTensor, kwsClassificationResult,
- ctx.Get<std::vector<std::string>&>("kwslabels"), 1);
+ ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
kwsResults.emplace_back(
kws::KwsResult(
@@ -604,7 +604,7 @@ namespace app {
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+ GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
{
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;