summaryrefslogtreecommitdiff
path: root/source/use_case/kws_asr/src/MainLoop.cc
diff options
context:
space:
mode:
authorKshitij Sisodia <kshitij.sisodia@arm.com>2021-12-24 11:05:11 +0000
committerLiam Barry <liam.barry@arm.com>2021-12-24 14:20:36 +0000
commit76a1580861210e0310db23acbc29e1064ae30ead (patch)
treef947145cffd944aa3724c90745fc0e9d8e2fb2f4 /source/use_case/kws_asr/src/MainLoop.cc
parent871fcdc755173b9f7ecb8cf9dc8dc6306329958c (diff)
downloadml-embedded-evaluation-kit-76a1580861210e0310db23acbc29e1064ae30ead.tar.gz
MLECO-2599: Replace DSCNN with MicroNet for KWS
Added SoftMax function to Mathutils to allow MicroNet to output probability as it does not nativelu have this layer. Minor refactoring to accommodate Softmax Calculations Extensive renaming and updating of documentation and resource download script. Added SoftMax function to Mathutils to allow MicroNet to output probability. Change-Id: I7cbbda1024d14b85c9ac1beea7ca8fbffd0b6eb5 Signed-off-by: Liam Barry <liam.barry@arm.com>
Diffstat (limited to 'source/use_case/kws_asr/src/MainLoop.cc')
-rw-r--r--source/use_case/kws_asr/src/MainLoop.cc10
1 files changed, 5 insertions, 5 deletions
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index d5a2c2b..30cb084 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -16,11 +16,11 @@
*/
#include "hal.h" /* Brings in platform definitions. */
#include "InputFiles.hpp" /* For input images. */
-#include "Labels_dscnn.hpp" /* For DS-CNN label strings. */
+#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */
#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
#include "Classifier.hpp" /* KWS classifier. */
#include "AsrClassifier.hpp" /* ASR classifier. */
-#include "DsCnnModel.hpp" /* KWS model class for running inference. */
+#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */
#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
@@ -69,7 +69,7 @@ static uint32_t GetOutputInnerLen(const arm::app::Model& model,
void main_loop(hal_platform& platform)
{
/* Model wrapper objects. */
- arm::app::DsCnnModel kwsModel;
+ arm::app::MicroNetKwsModel kwsModel;
arm::app::Wav2LetterModel asrModel;
/* Load the models. */
@@ -81,7 +81,7 @@ void main_loop(hal_platform& platform)
/* Initialise the asr model using the same allocator from KWS
* to re-use the tensor arena. */
if (!asrModel.Init(kwsModel.GetAllocator())) {
- printf_err("Failed to initalise ASR model\n");
+ printf_err("Failed to initialise ASR model\n");
return;
}
@@ -137,7 +137,7 @@ void main_loop(hal_platform& platform)
caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels);
/* Index of the kws outputs we trigger ASR on. */
- caseContext.Set<uint32_t>("keywordindex", 2);
+ caseContext.Set<uint32_t>("keywordindex", 9 );
/* Loop. */
bool executionSuccessful = true;