summaryrefslogtreecommitdiff
path: root/source/use_case/kws/src/UseCaseHandler.cc
diff options
context:
space:
mode:
authorKshitij Sisodia <kshitij.sisodia@arm.com>2021-12-24 11:05:11 +0000
committerLiam Barry <liam.barry@arm.com>2021-12-24 14:20:36 +0000
commit76a1580861210e0310db23acbc29e1064ae30ead (patch)
treef947145cffd944aa3724c90745fc0e9d8e2fb2f4 /source/use_case/kws/src/UseCaseHandler.cc
parent871fcdc755173b9f7ecb8cf9dc8dc6306329958c (diff)
downloadml-embedded-evaluation-kit-76a1580861210e0310db23acbc29e1064ae30ead.tar.gz
MLECO-2599: Replace DSCNN with MicroNet for KWS
Added SoftMax function to Mathutils to allow MicroNet to output probability, as it does not natively have this layer. Minor refactoring to accommodate Softmax calculations. Extensive renaming and updating of documentation and resource download script. Added SoftMax function to Mathutils to allow MicroNet to output probability. Change-Id: I7cbbda1024d14b85c9ac1beea7ca8fbffd0b6eb5 Signed-off-by: Liam Barry <liam.barry@arm.com>
Diffstat (limited to 'source/use_case/kws/src/UseCaseHandler.cc')
-rw-r--r--source/use_case/kws/src/UseCaseHandler.cc24
1 files changed, 11 insertions, 13 deletions
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index 3d95753..8085af7 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -18,9 +18,9 @@
#include "InputFiles.hpp"
#include "Classifier.hpp"
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsMfcc.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "KwsResult.hpp"
@@ -59,7 +59,7 @@ namespace app {
* @return Function to be called providing audio sample and sliding window index.
*/
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
+ GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
TfLiteTensor* inputTensor,
size_t cacheSize);
@@ -72,8 +72,8 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
constexpr int minTensorDims = static_cast<int>(
- (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
- arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);
+ (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
+ arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
auto& model = ctx.Get<Model&>("model");
@@ -105,10 +105,10 @@ namespace app {
}
TfLiteIntArray* inputShape = model.GetInputShape(0);
- const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx];
- const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx];
+ const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx];
+ const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];
- audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength);
+ audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength);
mfcc.Init();
/* Deduce the data length required for 1 inference from the network parameters. */
@@ -132,7 +132,7 @@ namespace app {
/* We expect to be sampling 1 second worth of data at a time.
* NOTE: This is only used for time stamp calculation. */
- const float secondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq;
+ const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
do {
platform.data_psn->clear(COLOR_BLACK);
@@ -208,7 +208,7 @@ namespace app {
std::vector<ClassificationResult> classificationResult;
auto& classifier = ctx.Get<KwsClassifier&>("classifier");
classifier.GetClassificationResults(outputTensor, classificationResult,
- ctx.Get<std::vector<std::string>&>("labels"), 1);
+ ctx.Get<std::vector<std::string>&>("labels"), 1, true);
results.emplace_back(kws::KwsResult(classificationResult,
audioDataSlider.Index() * secondsPerSample * audioDataStride,
@@ -240,7 +240,6 @@ namespace app {
return true;
}
-
static bool PresentInferenceResult(hal_platform& platform,
const std::vector<arm::app::kws::KwsResult>& results)
{
@@ -259,7 +258,6 @@ namespace app {
std::string topKeyword{"<none>"};
float score = 0.f;
-
if (!results[i].m_resultVec.empty()) {
topKeyword = results[i].m_resultVec[0].m_label;
score = results[i].m_resultVec[0].m_normalisedVal;
@@ -366,7 +364,7 @@ namespace app {
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+ GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
{
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;