summaryrefslogtreecommitdiff
path: root/tests/use_case/kws/KWSHandlerTest.cc
diff options
context:
space:
mode:
authorKshitij Sisodia <kshitij.sisodia@arm.com>2021-12-24 11:05:11 +0000
committerLiam Barry <liam.barry@arm.com>2021-12-24 14:20:36 +0000
commit76a1580861210e0310db23acbc29e1064ae30ead (patch)
treef947145cffd944aa3724c90745fc0e9d8e2fb2f4 /tests/use_case/kws/KWSHandlerTest.cc
parent871fcdc755173b9f7ecb8cf9dc8dc6306329958c (diff)
downloadml-embedded-evaluation-kit-76a1580861210e0310db23acbc29e1064ae30ead.tar.gz
MLECO-2599: Replace DSCNN with MicroNet for KWS
Added SoftMax function to Mathutils to allow MicroNet to output probability as it does not nativelu have this layer. Minor refactoring to accommodate Softmax Calculations Extensive renaming and updating of documentation and resource download script. Added SoftMax function to Mathutils to allow MicroNet to output probability. Change-Id: I7cbbda1024d14b85c9ac1beea7ca8fbffd0b6eb5 Signed-off-by: Liam Barry <liam.barry@arm.com>
Diffstat (limited to 'tests/use_case/kws/KWSHandlerTest.cc')
-rw-r--r--tests/use_case/kws/KWSHandlerTest.cc30
1 files changed, 15 insertions, 15 deletions
diff --git a/tests/use_case/kws/KWSHandlerTest.cc b/tests/use_case/kws/KWSHandlerTest.cc
index 50e5a83..a7e75fb 100644
--- a/tests/use_case/kws/KWSHandlerTest.cc
+++ b/tests/use_case/kws/KWSHandlerTest.cc
@@ -15,7 +15,7 @@
* limitations under the License.
*/
#include <catch.hpp>
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
#include "KwsResult.hpp"
@@ -27,7 +27,7 @@
TEST_CASE("Model info")
{
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());
@@ -53,7 +53,7 @@ TEST_CASE("Inference by index")
hal_platform_init(&platform);
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());
@@ -65,8 +65,8 @@ TEST_CASE("Inference by index")
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
- caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */
- caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */
+ caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNetKws. */
+ caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNetKws. */
caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */
arm::app::Classifier classifier; /* classifier wrapper object. */
@@ -97,25 +97,25 @@ TEST_CASE("Inference by index")
SECTION("Index = 0, short clip down")
{
/* Result: down. */
- checker(0, {5});
+ checker(0, {0});
}
SECTION("Index = 1, long clip right->left->up")
{
/* Result: right->right->left->up->up. */
- checker(1, {7, 1, 6, 4, 4});
+ checker(1, {6, 6, 2, 8, 8});
}
SECTION("Index = 2, short clip yes")
{
/* Result: yes. */
- checker(2, {2});
+ checker(2, {9});
}
SECTION("Index = 3, long clip yes->no->go->stop")
{
- /* Result: yes->go->no->go->go->go->stop. */
- checker(3, {2, 11, 3, 11, 11, 11, 10});
+ /* Result: yes->no->no->go->go->stop->stop. */
+ checker(3, {9, 3, 3, 1, 1, 7, 7});
}
}
@@ -132,7 +132,7 @@ TEST_CASE("Inference run all clips")
hal_platform_init(&platform);
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());
@@ -145,9 +145,9 @@ TEST_CASE("Inference run all clips")
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
- caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */
- caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */
- caseContext.Set<float>("scoreThreshold", 0.9); /* Normalised score threshold. */
+ caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNet. */
+ caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNet. */
+ caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */
arm::app::Classifier classifier; /* classifier wrapper object. */
caseContext.Set<arm::app::Classifier&>("classifier", classifier);
@@ -170,7 +170,7 @@ TEST_CASE("List all audio clips")
hal_platform_init(&platform);
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());