diff options
Diffstat (limited to 'tests/use_case/kws')
-rw-r--r-- | tests/use_case/kws/InferenceTestMicroNetKws.cc (renamed from tests/use_case/kws/InferenceTestDSCNN.cc) | 10 | ||||
-rw-r--r-- | tests/use_case/kws/KWSHandlerTest.cc | 30 | ||||
-rw-r--r-- | tests/use_case/kws/MfccTests.cc | 8 |
3 files changed, 24 insertions, 24 deletions
diff --git a/tests/use_case/kws/InferenceTestDSCNN.cc b/tests/use_case/kws/InferenceTestMicroNetKws.cc index 8918073..e6e7753 100644 --- a/tests/use_case/kws/InferenceTestDSCNN.cc +++ b/tests/use_case/kws/InferenceTestMicroNetKws.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" #include "TestData_kws.hpp" #include "TensorFlowLiteMicro.hpp" @@ -74,9 +74,9 @@ void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app:: } } -TEST_CASE("Running random inference with TensorFlow Lite Micro and DsCnnModel Int8", "[DS_CNN]") +TEST_CASE("Running random inference with TensorFlow Lite Micro and MicroNetKwsModel Int8", "[MicroNetKws]") { - arm::app::DsCnnModel model{}; + arm::app::MicroNetKwsModel model{}; REQUIRE_FALSE(model.IsInited()); REQUIRE(model.Init()); @@ -85,7 +85,7 @@ TEST_CASE("Running random inference with TensorFlow Lite Micro and DsCnnModel In REQUIRE(RunInferenceRandom(model)); } -TEST_CASE("Running inference with TensorFlow Lite Micro and DsCnnModel Uint8", "[DS_CNN]") +TEST_CASE("Running inference with TensorFlow Lite Micro and MicroNetKwsModel int8", "[MicroNetKws]") { REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES); for (uint32_t i = 0 ; i < NUMBER_OF_IFM_FILES; ++i) { @@ -94,7 +94,7 @@ TEST_CASE("Running inference with TensorFlow Lite Micro and DsCnnModel Uint8", " DYNAMIC_SECTION("Executing inference with re-init") { - arm::app::DsCnnModel model{}; + arm::app::MicroNetKwsModel model{}; REQUIRE_FALSE(model.IsInited()); REQUIRE(model.Init()); diff --git a/tests/use_case/kws/KWSHandlerTest.cc b/tests/use_case/kws/KWSHandlerTest.cc index 50e5a83..a7e75fb 100644 --- a/tests/use_case/kws/KWSHandlerTest.cc +++ b/tests/use_case/kws/KWSHandlerTest.cc @@ -15,7 +15,7 @@ * limitations under the License. */ #include <catch.hpp> -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" #include "KwsResult.hpp" @@ -27,7 +27,7 @@ TEST_CASE("Model info") { /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -53,7 +53,7 @@ TEST_CASE("Inference by index") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -65,8 +65,8 @@ TEST_CASE("Inference by index") caseContext.Set<arm::app::Profiler&>("profiler", profiler); caseContext.Set<hal_platform&>("platform", platform); caseContext.Set<arm::app::Model&>("model", model); - caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */ - caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */ + caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNetKws. */ + caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNetKws. */ caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */ arm::app::Classifier classifier; /* classifier wrapper object. */ @@ -97,25 +97,25 @@ TEST_CASE("Inference by index") SECTION("Index = 0, short clip down") { /* Result: down. */ - checker(0, {5}); + checker(0, {0}); } SECTION("Index = 1, long clip right->left->up") { /* Result: right->right->left->up->up. */ - checker(1, {7, 1, 6, 4, 4}); + checker(1, {6, 6, 2, 8, 8}); } SECTION("Index = 2, short clip yes") { /* Result: yes. */ - checker(2, {2}); + checker(2, {9}); } SECTION("Index = 3, long clip yes->no->go->stop") { - /* Result: yes->go->no->go->go->go->stop. */ - checker(3, {2, 11, 3, 11, 11, 11, 10}); + /* Result: yes->no->no->go->go->stop->stop. */ + checker(3, {9, 3, 3, 1, 1, 7, 7}); } } @@ -132,7 +132,7 @@ TEST_CASE("Inference run all clips") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -145,9 +145,9 @@ TEST_CASE("Inference run all clips") caseContext.Set<hal_platform&>("platform", platform); caseContext.Set<arm::app::Model&>("model", model); caseContext.Set<uint32_t>("clipIndex", 0); - caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */ - caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */ - caseContext.Set<float>("scoreThreshold", 0.9); /* Normalised score threshold. */ + caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNet. */ + caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNet. */ + caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */ arm::app::Classifier classifier; /* classifier wrapper object. */ caseContext.Set<arm::app::Classifier&>("classifier", classifier); @@ -170,7 +170,7 @@ TEST_CASE("List all audio clips") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); diff --git a/tests/use_case/kws/MfccTests.cc b/tests/use_case/kws/MfccTests.cc index 407861f..1d30ef4 100644 --- a/tests/use_case/kws/MfccTests.cc +++ b/tests/use_case/kws/MfccTests.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsMfcc.hpp" #include <algorithm> #include <catch.hpp> @@ -93,13 +93,13 @@ const std::vector<float> testWavMfcc { -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072, }; -arm::app::audio::DsCnnMFCC GetMFCCInstance() { - const int sampFreq = arm::app::audio::DsCnnMFCC::ms_defaultSamplingFreq; +arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() { + const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; const int frameLenMs = 40; const int frameLenSamples = sampFreq * frameLenMs * 0.001; const int numMfccFeats = 10; - return arm::app::audio::DsCnnMFCC(numMfccFeats, frameLenSamples); + return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples); } template <class T> |