diff options
Diffstat (limited to 'tests/use_case/kws/KWSHandlerTest.cc')
-rw-r--r-- | tests/use_case/kws/KWSHandlerTest.cc | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/tests/use_case/kws/KWSHandlerTest.cc b/tests/use_case/kws/KWSHandlerTest.cc index 50e5a83..a7e75fb 100644 --- a/tests/use_case/kws/KWSHandlerTest.cc +++ b/tests/use_case/kws/KWSHandlerTest.cc @@ -15,7 +15,7 @@ * limitations under the License. */ #include <catch.hpp> -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" #include "KwsResult.hpp" @@ -27,7 +27,7 @@ TEST_CASE("Model info") { /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -53,7 +53,7 @@ TEST_CASE("Inference by index") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -65,8 +65,8 @@ TEST_CASE("Inference by index") caseContext.Set<arm::app::Profiler&>("profiler", profiler); caseContext.Set<hal_platform&>("platform", platform); caseContext.Set<arm::app::Model&>("model", model); - caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */ - caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */ + caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNetKws. */ + caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNetKws. */ caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */ arm::app::Classifier classifier; /* classifier wrapper object. */ @@ -97,25 +97,25 @@ TEST_CASE("Inference by index") SECTION("Index = 0, short clip down") { /* Result: down. */ - checker(0, {5}); + checker(0, {0}); } SECTION("Index = 1, long clip right->left->up") { /* Result: right->right->left->up->up. */ - checker(1, {7, 1, 6, 4, 4}); + checker(1, {6, 6, 2, 8, 8}); } SECTION("Index = 2, short clip yes") { /* Result: yes. */ - checker(2, {2}); + checker(2, {9}); } SECTION("Index = 3, long clip yes->no->go->stop") { - /* Result: yes->go->no->go->go->go->stop. */ - checker(3, {2, 11, 3, 11, 11, 11, 10}); + /* Result: yes->no->no->go->go->stop->stop. */ + checker(3, {9, 3, 3, 1, 1, 7, 7}); } } @@ -132,7 +132,7 @@ TEST_CASE("Inference run all clips") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -145,9 +145,9 @@ TEST_CASE("Inference run all clips") caseContext.Set<hal_platform&>("platform", platform); caseContext.Set<arm::app::Model&>("model", model); caseContext.Set<uint32_t>("clipIndex", 0); - caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */ - caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */ - caseContext.Set<float>("scoreThreshold", 0.9); /* Normalised score threshold. */ + caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNet. */ + caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNet. */ + caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */ arm::app::Classifier classifier; /* classifier wrapper object. */ caseContext.Set<arm::app::Classifier&>("classifier", classifier); @@ -170,7 +170,7 @@ TEST_CASE("List all audio clips") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); |