From 76a1580861210e0310db23acbc29e1064ae30ead Mon Sep 17 00:00:00 2001
From: Kshitij Sisodia
Date: Fri, 24 Dec 2021 11:05:11 +0000
Subject: MLECO-2599: Replace DSCNN with MicroNet for KWS

Added SoftMax function to Mathutils to allow MicroNet to output
probabilities, as it does not natively have this layer.
Minor refactoring to accommodate the Softmax calculations.
Extensive renaming and updating of documentation and resource
download script.

Change-Id: I7cbbda1024d14b85c9ac1beea7ca8fbffd0b6eb5
Signed-off-by: Liam Barry
---
 tests/common/ClassifierTests.cc                    |   4 +-
 tests/use_case/kws/InferenceTestDSCNN.cc           | 107 ---------------------
 tests/use_case/kws/InferenceTestMicroNetKws.cc     | 107 +++++++++++++++++++++
 tests/use_case/kws/KWSHandlerTest.cc               |  30 +++---
 tests/use_case/kws/MfccTests.cc                    |   8 +-
 tests/use_case/kws_asr/InferenceTestDSCNN.cc       | 105 --------------------
 tests/use_case/kws_asr/InferenceTestMicroNetKws.cc | 105 ++++++++++++++++++++
 tests/use_case/kws_asr/InitModels.cc               |   6 +-
 tests/use_case/kws_asr/MfccTests.cc                |  22 ++---
 9 files changed, 247 insertions(+), 247 deletions(-)
 delete mode 100644 tests/use_case/kws/InferenceTestDSCNN.cc
 create mode 100644 tests/use_case/kws/InferenceTestMicroNetKws.cc
 delete mode 100644 tests/use_case/kws_asr/InferenceTestDSCNN.cc
 create mode 100644 tests/use_case/kws_asr/InferenceTestMicroNetKws.cc

(limited to 'tests')

diff --git a/tests/common/ClassifierTests.cc b/tests/common/ClassifierTests.cc
index d950304..693f744 100644
--- a/tests/common/ClassifierTests.cc
+++ b/tests/common/ClassifierTests.cc
@@ -35,7 +35,7 @@ void test_classifier_result(std::vector<std::pair<uint32_t, float>>& selectedResults
     }
 
     arm::app::Classifier classifier;
-    REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5));
+    REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5, true));
     REQUIRE(5 == resultVec.size());
 
     for (size_t i = 0; i < resultVec.size(); ++i) {
@@ -50,7 +50,7 @@ TEST_CASE("Common classifier")
         TfLiteTensor* outputTens = nullptr;
         std::vector<arm::app::ClassificationResult> resultVec;
         arm::app::Classifier classifier;
-        REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5));
+        REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5, true));
     }
 
     SECTION("Test classification results")
diff --git a/tests/use_case/kws/InferenceTestDSCNN.cc b/tests/use_case/kws/InferenceTestDSCNN.cc
deleted file mode 100644
index 8918073..0000000
--- a/tests/use_case/kws/InferenceTestDSCNN.cc
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "DsCnnModel.hpp"
-#include "hal.h"
-#include "TestData_kws.hpp"
-#include "TensorFlowLiteMicro.hpp"
-
-#include <catch.hpp>
-#include <random>
-
-using namespace test;
-
-bool RunInference(arm::app::Model& model, const int8_t vec[])
-{
-    TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    REQUIRE(inputTensor);
-
-    const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ?
-                            inputTensor->bytes :
-                            IFM_0_DATA_SIZE;
-    memcpy(inputTensor->data.data, vec, copySz);
-
-    return model.RunInference();
-}
-
-bool RunInferenceRandom(arm::app::Model& model)
-{
-    TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    REQUIRE(inputTensor);
-
-    std::random_device rndDevice;
-    std::mt19937 mersenneGen{rndDevice()};
-    std::uniform_int_distribution<int8_t> dist {-128, 127};
-
-    auto gen = [&dist, &mersenneGen](){
-                   return dist(mersenneGen);
-               };
-
-    std::vector<int8_t> randomAudio(inputTensor->bytes);
-    std::generate(std::begin(randomAudio), std::end(randomAudio), gen);
-
-    REQUIRE(RunInference(model, randomAudio.data()));
-    return true;
-}
-
-template <typename T>
-void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model)
-{
-    REQUIRE(RunInference(model, input_goldenFV));
-
-    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
-
-    REQUIRE(outputTensor);
-    REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE);
-    auto tensorData = tflite::GetTensorData<T>(outputTensor);
-    REQUIRE(tensorData);
-
-    for (size_t i = 0; i < outputTensor->bytes; i++) {
-        REQUIRE(static_cast<int>(tensorData[i]) == static_cast<int>(((T)output_goldenFV[i])));
-    }
-}
-
-TEST_CASE("Running random inference with TensorFlow Lite Micro and DsCnnModel Int8", "[DS_CNN]")
-{
-    arm::app::DsCnnModel model{};
-
-    REQUIRE_FALSE(model.IsInited());
-    REQUIRE(model.Init());
-    REQUIRE(model.IsInited());
-
-    REQUIRE(RunInferenceRandom(model));
-}
-
-TEST_CASE("Running inference with TensorFlow Lite Micro and DsCnnModel Uint8", "[DS_CNN]")
-{
-    REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES);
-    for (uint32_t i = 0 ; i < NUMBER_OF_IFM_FILES; ++i) {
-        const int8_t* input_goldenFV = get_ifm_data_array(i);;
-        const int8_t* output_goldenFV = get_ofm_data_array(i);
-
-        DYNAMIC_SECTION("Executing inference with re-init")
-        {
-            arm::app::DsCnnModel model{};
-
-            REQUIRE_FALSE(model.IsInited());
-            REQUIRE(model.Init());
-            REQUIRE(model.IsInited());
-
-            TestInference<int8_t>(input_goldenFV, output_goldenFV, model);
-
-        }
-    }
-}
diff --git a/tests/use_case/kws/InferenceTestMicroNetKws.cc b/tests/use_case/kws/InferenceTestMicroNetKws.cc
new file mode 100644
index 0000000..e6e7753
--- /dev/null
+++ b/tests/use_case/kws/InferenceTestMicroNetKws.cc
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "MicroNetKwsModel.hpp"
+#include "hal.h"
+#include "TestData_kws.hpp"
+#include "TensorFlowLiteMicro.hpp"
+
+#include <catch.hpp>
+#include <random>
+
+using namespace test;
+
+bool RunInference(arm::app::Model& model, const int8_t vec[])
+{
+    TfLiteTensor* inputTensor = model.GetInputTensor(0);
+    REQUIRE(inputTensor);
+
+    const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ?
+                            inputTensor->bytes :
+                            IFM_0_DATA_SIZE;
+    memcpy(inputTensor->data.data, vec, copySz);
+
+    return model.RunInference();
+}
+
+bool RunInferenceRandom(arm::app::Model& model)
+{
+    TfLiteTensor* inputTensor = model.GetInputTensor(0);
+    REQUIRE(inputTensor);
+
+    std::random_device rndDevice;
+    std::mt19937 mersenneGen{rndDevice()};
+    std::uniform_int_distribution<int8_t> dist {-128, 127};
+
+    auto gen = [&dist, &mersenneGen](){
+                   return dist(mersenneGen);
+               };
+
+    std::vector<int8_t> randomAudio(inputTensor->bytes);
+    std::generate(std::begin(randomAudio), std::end(randomAudio), gen);
+
+    REQUIRE(RunInference(model, randomAudio.data()));
+    return true;
+}
+
+template <typename T>
+void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model)
+{
+    REQUIRE(RunInference(model, input_goldenFV));
+
+    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+
+    REQUIRE(outputTensor);
+    REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE);
+    auto tensorData = tflite::GetTensorData<T>(outputTensor);
+    REQUIRE(tensorData);
+
+    for (size_t i = 0; i < outputTensor->bytes; i++) {
+        REQUIRE(static_cast<int>(tensorData[i]) == static_cast<int>(((T)output_goldenFV[i])));
+    }
+}
+
+TEST_CASE("Running random inference with TensorFlow Lite Micro and MicroNetKwsModel Int8", "[MicroNetKws]")
+{
+    arm::app::MicroNetKwsModel model{};
+
+    REQUIRE_FALSE(model.IsInited());
+    REQUIRE(model.Init());
+    REQUIRE(model.IsInited());
+
+    REQUIRE(RunInferenceRandom(model));
+}
+
+TEST_CASE("Running inference with TensorFlow Lite Micro and MicroNetKwsModel int8", "[MicroNetKws]")
+{
+    REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES);
+    for (uint32_t i = 0 ; i < NUMBER_OF_IFM_FILES; ++i) {
+        const int8_t* input_goldenFV = get_ifm_data_array(i);;
+        const int8_t* output_goldenFV = get_ofm_data_array(i);
+
+        DYNAMIC_SECTION("Executing inference with re-init")
+        {
+            arm::app::MicroNetKwsModel model{};
+
+            REQUIRE_FALSE(model.IsInited());
+            REQUIRE(model.Init());
+            REQUIRE(model.IsInited());
+
+            TestInference<int8_t>(input_goldenFV, output_goldenFV, model);
+
+        }
+    }
+}
diff --git a/tests/use_case/kws/KWSHandlerTest.cc b/tests/use_case/kws/KWSHandlerTest.cc
index 50e5a83..a7e75fb 100644
--- a/tests/use_case/kws/KWSHandlerTest.cc
+++ b/tests/use_case/kws/KWSHandlerTest.cc
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 #include <catch.hpp>
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
 #include "hal.h"
 
 #include "KwsResult.hpp"
@@ -27,7 +27,7 @@ TEST_CASE("Model info")
 {
     /* Model wrapper object. */
-    arm::app::DsCnnModel model;
+    arm::app::MicroNetKwsModel model;
 
     /* Load the model. */
     REQUIRE(model.Init());
@@ -53,7 +53,7 @@ TEST_CASE("Inference by index")
     hal_platform_init(&platform);
 
     /* Model wrapper object. */
-    arm::app::DsCnnModel model;
+    arm::app::MicroNetKwsModel model;
 
     /* Load the model. */
     REQUIRE(model.Init());
@@ -65,8 +65,8 @@ TEST_CASE("Inference by index")
     caseContext.Set("profiler", profiler);
     caseContext.Set("platform", platform);
     caseContext.Set("model", model);
-    caseContext.Set("frameLength", g_FrameLength);  /* 640 sample length for DSCNN. */
-    caseContext.Set("frameStride", g_FrameStride);  /* 320 sample stride for DSCNN. */
+    caseContext.Set("frameLength", g_FrameLength);  /* 640 sample length for MicroNetKws. */
+    caseContext.Set("frameStride", g_FrameStride);  /* 320 sample stride for MicroNetKws. */
     caseContext.Set("scoreThreshold", 0.5);         /* Normalised score threshold. */
     arm::app::Classifier classifier;                /* classifier wrapper object. */
     caseContext.Set("classifier", classifier);
@@ -97,25 +97,25 @@ TEST_CASE("Inference by index")
     SECTION("Index = 0, short clip down")
     {
         /* Result: down. */
-        checker(0, {5});
+        checker(0, {0});
     }
 
     SECTION("Index = 1, long clip right->left->up")
     {
         /* Result: right->right->left->up->up. */
-        checker(1, {7, 1, 6, 4, 4});
+        checker(1, {6, 6, 2, 8, 8});
     }
 
     SECTION("Index = 2, short clip yes")
     {
         /* Result: yes. */
-        checker(2, {2});
+        checker(2, {9});
     }
 
     SECTION("Index = 3, long clip yes->no->go->stop")
     {
-        /* Result: yes->go->no->go->go->go->stop. */
-        checker(3, {2, 11, 3, 11, 11, 11, 10});
+        /* Result: yes->no->no->go->go->stop->stop. */
+        checker(3, {9, 3, 3, 1, 1, 7, 7});
     }
 }
@@ -132,7 +132,7 @@ TEST_CASE("Inference run all clips")
     hal_platform_init(&platform);
 
     /* Model wrapper object. */
-    arm::app::DsCnnModel model;
+    arm::app::MicroNetKwsModel model;
 
     /* Load the model. */
     REQUIRE(model.Init());
@@ -145,9 +145,9 @@ TEST_CASE("Inference run all clips")
     caseContext.Set("platform", platform);
     caseContext.Set("model", model);
     caseContext.Set("clipIndex", 0);
-    caseContext.Set("frameLength", g_FrameLength);  /* 640 sample length for DSCNN. */
-    caseContext.Set("frameStride", g_FrameStride);  /* 320 sample stride for DSCNN. */
-    caseContext.Set("scoreThreshold", 0.9);         /* Normalised score threshold. */
+    caseContext.Set("frameLength", g_FrameLength);  /* 640 sample length for MicroNet. */
+    caseContext.Set("frameStride", g_FrameStride);  /* 320 sample stride for MicroNet. */
+    caseContext.Set("scoreThreshold", 0.7);         /* Normalised score threshold. */
     arm::app::Classifier classifier;                /* classifier wrapper object. */
     caseContext.Set("classifier", classifier);
@@ -170,7 +170,7 @@ TEST_CASE("List all audio clips")
     hal_platform_init(&platform);
 
     /* Model wrapper object. */
-    arm::app::DsCnnModel model;
+    arm::app::MicroNetKwsModel model;
 
     /* Load the model. */
     REQUIRE(model.Init());
diff --git a/tests/use_case/kws/MfccTests.cc b/tests/use_case/kws/MfccTests.cc
index 407861f..1d30ef4 100644
--- a/tests/use_case/kws/MfccTests.cc
+++ b/tests/use_case/kws/MfccTests.cc
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsMfcc.hpp"
 
 #include
 #include
@@ -93,13 +93,13 @@ const std::vector<float> testWavMfcc {
     -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
 };
 
-arm::app::audio::DsCnnMFCC GetMFCCInstance() {
-    const int sampFreq = arm::app::audio::DsCnnMFCC::ms_defaultSamplingFreq;
+arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() {
+    const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
     const int frameLenMs = 40;
     const int frameLenSamples = sampFreq * frameLenMs * 0.001;
     const int numMfccFeats = 10;
-    return arm::app::audio::DsCnnMFCC(numMfccFeats, frameLenSamples);
+    return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples);
 }
 
 template <typename T>
diff --git a/tests/use_case/kws_asr/InferenceTestDSCNN.cc b/tests/use_case/kws_asr/InferenceTestDSCNN.cc
deleted file mode 100644
index ad1731b..0000000
--- a/tests/use_case/kws_asr/InferenceTestDSCNN.cc
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "DsCnnModel.hpp"
-#include "hal.h"
-#include "TestData_kws.hpp"
-#include "TensorFlowLiteMicro.hpp"
-
-#include <catch.hpp>
-#include <random>
-
-namespace test {
-namespace kws {
-
-bool RunInference(arm::app::Model& model, const int8_t vec[]) {
-    TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    REQUIRE(inputTensor);
-
-    const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ?
-                          inputTensor->bytes :
-                          IFM_0_DATA_SIZE;
-    memcpy(inputTensor->data.data, vec, copySz);
-
-    return model.RunInference();
-}
-
-bool RunInferenceRandom(arm::app::Model& model) {
-    TfLiteTensor* inputTensor = model.GetInputTensor(0);
-    REQUIRE(inputTensor);
-
-    std::random_device rndDevice;
-    std::mt19937 mersenneGen{rndDevice()};
-    std::uniform_int_distribution<int8_t> dist{-128, 127};
-
-    auto gen = [&dist, &mersenneGen]() {
-        return dist(mersenneGen);
-    };
-
-    std::vector<int8_t> randomAudio(inputTensor->bytes);
-    std::generate(std::begin(randomAudio), std::end(randomAudio), gen);
-
-    REQUIRE(RunInference(model, randomAudio.data()));
-    return true;
-}
-
-template <typename T>
-void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model) {
-    REQUIRE(RunInference(model, input_goldenFV));
-
-    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
-
-    REQUIRE(outputTensor);
-    REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE);
-    auto tensorData = tflite::GetTensorData<T>(outputTensor);
-    REQUIRE(tensorData);
-
-    for (size_t i = 0; i < outputTensor->bytes; i++) {
-        REQUIRE(static_cast<int>(tensorData[i]) == static_cast<int>((T) output_goldenFV[i]));
-    }
-}
-
-TEST_CASE("Running random inference with Tflu and DsCnnModel Int8", "[DS_CNN]") {
-    arm::app::DsCnnModel model{};
-
-    REQUIRE_FALSE(model.IsInited());
-    REQUIRE(model.Init());
-    REQUIRE(model.IsInited());
-
-    REQUIRE(RunInferenceRandom(model));
-}
-
-TEST_CASE("Running inference with Tflu and DsCnnModel Uint8", "[DS_CNN]") {
-    REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES);
-    for (uint32_t i = 0; i < NUMBER_OF_IFM_FILES; ++i) {
-        const int8_t* input_goldenFV = get_ifm_data_array(i);
-        const int8_t* output_goldenFV = get_ofm_data_array(i);
-
-        DYNAMIC_SECTION("Executing inference with re-init") {
-            arm::app::DsCnnModel model{};
-
-            REQUIRE_FALSE(model.IsInited());
-            REQUIRE(model.Init());
-            REQUIRE(model.IsInited());
-
-            TestInference<int8_t>(input_goldenFV, output_goldenFV, model);
-
-        }
-    }
-}
-
-} //namespace
-} //namespace
\ No newline at end of file
diff --git a/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc b/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc
new file mode 100644
index 0000000..fd379b6
--- /dev/null
+++ b/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "MicroNetKwsModel.hpp"
+#include "hal.h"
+#include "TestData_kws.hpp"
+#include "TensorFlowLiteMicro.hpp"
+
+#include <catch.hpp>
+#include <random>
+
+namespace test {
+namespace kws {
+
+bool RunInference(arm::app::Model& model, const int8_t vec[]) {
+    TfLiteTensor* inputTensor = model.GetInputTensor(0);
+    REQUIRE(inputTensor);
+
+    const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ?
+                          inputTensor->bytes :
+                          IFM_0_DATA_SIZE;
+    memcpy(inputTensor->data.data, vec, copySz);
+
+    return model.RunInference();
+}
+
+bool RunInferenceRandom(arm::app::Model& model) {
+    TfLiteTensor* inputTensor = model.GetInputTensor(0);
+    REQUIRE(inputTensor);
+
+    std::random_device rndDevice;
+    std::mt19937 mersenneGen{rndDevice()};
+    std::uniform_int_distribution<int8_t> dist{-128, 127};
+
+    auto gen = [&dist, &mersenneGen]() {
+        return dist(mersenneGen);
+    };
+
+    std::vector<int8_t> randomAudio(inputTensor->bytes);
+    std::generate(std::begin(randomAudio), std::end(randomAudio), gen);
+
+    REQUIRE(RunInference(model, randomAudio.data()));
+    return true;
+}
+
+template <typename T>
+void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model) {
+    REQUIRE(RunInference(model, input_goldenFV));
+
+    TfLiteTensor* outputTensor = model.GetOutputTensor(0);
+
+    REQUIRE(outputTensor);
+    REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE);
+    auto tensorData = tflite::GetTensorData<T>(outputTensor);
+    REQUIRE(tensorData);
+
+    for (size_t i = 0; i < outputTensor->bytes; i++) {
+        REQUIRE(static_cast<int>(tensorData[i]) == static_cast<int>((T) output_goldenFV[i]));
+    }
+}
+
+TEST_CASE("Running random inference with Tflu and MicroNetKwsModel Int8", "[MicroNetKws]") {
+    arm::app::MicroNetKwsModel model{};
+
+    REQUIRE_FALSE(model.IsInited());
+    REQUIRE(model.Init());
+    REQUIRE(model.IsInited());
+
+    REQUIRE(RunInferenceRandom(model));
+}
+
+TEST_CASE("Running inference with Tflu and MicroNetKwsModel Int8", "[MicroNetKws]") {
+    REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES);
+    for (uint32_t i = 0; i < NUMBER_OF_IFM_FILES; ++i) {
+        const int8_t* input_goldenFV = get_ifm_data_array(i);
+        const int8_t* output_goldenFV = get_ofm_data_array(i);
+
+        DYNAMIC_SECTION("Executing inference with re-init") {
+            arm::app::MicroNetKwsModel model{};
+
+            REQUIRE_FALSE(model.IsInited());
+            REQUIRE(model.Init());
+            REQUIRE(model.IsInited());
+
+            TestInference<int8_t>(input_goldenFV, output_goldenFV, model);
+
+        }
+    }
+}
+
+} //namespace
+} //namespace
\ No newline at end of file
diff --git a/tests/use_case/kws_asr/InitModels.cc b/tests/use_case/kws_asr/InitModels.cc
index 770944d..97aa092 100644
--- a/tests/use_case/kws_asr/InitModels.cc
+++ b/tests/use_case/kws_asr/InitModels.cc
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
 #include "Wav2LetterModel.hpp"
 
 #include <catch.hpp>
@@ -22,8 +22,8 @@
 /* Skip this test, Wav2LetterModel if not Vela optimized but only from ML-zoo will fail. */
 TEST_CASE("Init two Models", "[.]")
 {
-    arm::app::DsCnnModel model1;
-    arm::app::DsCnnModel model2;
+    arm::app::MicroNetKwsModel model1;
+    arm::app::MicroNetKwsModel model2;
 
     /* Ideally we should load the wav2letter model here, but there is
      * none available to run on native (ops not supported on unoptimised
diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc
index 9509519..c0fb723 100644
--- a/tests/use_case/kws_asr/MfccTests.cc
+++ b/tests/use_case/kws_asr/MfccTests.cc
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsMfcc.hpp"
 
 #include
 #include
@@ -93,17 +93,17 @@ const std::vector<float> testWavMfcc {
     -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
 };
 
-arm::app::audio::DsCnnMFCC GetMFCCInstance() {
-    const int sampFreq = arm::app::audio::DsCnnMFCC::ms_defaultSamplingFreq;
+arm::app::audio::MicroNetMFCC GetMFCCInstance() {
+    const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq;
     const int frameLenMs = 40;
     const int frameLenSamples = sampFreq * frameLenMs * 0.001;
     const int numMfccFeats = 10;
-    return arm::app::audio::DsCnnMFCC(numMfccFeats, frameLenSamples);
+    return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples);
 }
 
 template <typename T>
-void TestQuntisedMFCC() {
+void TestQuantisedMFCC() {
     const float quantScale = 1.1088106632232666;
     const int quantOffset = 95;
     std::vector<T> mfccOutput = GetMFCCInstance().MfccComputeQuant<T>(testWav, quantScale, quantOffset);
@@ -118,9 +118,9 @@ void TestQuntisedMFCC() {
         REQUIRE(quantizedTestWavMfcc == Approx(mfccOutput[i]).margin(0));
     }
 }
-template void TestQuntisedMFCC<int8_t>();
-template void TestQuntisedMFCC<uint8_t>();
-template void TestQuntisedMFCC<int16_t>();
+template void TestQuantisedMFCC<int8_t>();
+template void TestQuantisedMFCC<uint8_t>();
+template void TestQuantisedMFCC<int16_t>();
 
 TEST_CASE("MFCC calculation test")
 {
@@ -141,16 +141,16 @@ TEST_CASE("MFCC calculation test")
 
     SECTION("int8_t")
     {
-        TestQuntisedMFCC<int8_t>();
+        TestQuantisedMFCC<int8_t>();
     }
 
     SECTION("uint8_t")
     {
-        TestQuntisedMFCC<uint8_t>();
+        TestQuantisedMFCC<uint8_t>();
     }
 
     SECTION("MFCC quant calculation test - int16_t")
    {
-        TestQuntisedMFCC<int16_t>();
+        TestQuantisedMFCC<int16_t>();
     }
 }
\ No newline at end of file
--
cgit v1.2.1
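
A note on the SoftMax helper described in the commit message: the Mathutils change itself is outside this view, which is limited to 'tests', and the extra boolean now passed to GetClassificationResults(..., 5, true) in ClassifierTests.cc presumably switches that softmax step on. The sketch below shows one way such a helper could look, assuming a plain, numerically stable softmax over dequantised output scores; the function name SoftmaxSketch and its signature are illustrative only and are not taken from this patch.

/* Illustrative sketch only: converts raw scores into a probability
 * distribution, in the spirit of the SoftMax helper the commit message
 * describes. Not the actual Mathutils implementation. */
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static void SoftmaxSketch(std::vector<float>& scores)
{
    if (scores.empty()) {
        return;
    }
    /* Subtract the maximum first so std::exp cannot overflow for large scores. */
    const float maxVal = *std::max_element(scores.begin(), scores.end());
    float sum = 0.0f;
    for (auto& s : scores) {
        s = std::exp(s - maxVal);
        sum += s;
    }
    /* Normalise so the outputs sum to 1 and can be read as probabilities. */
    for (auto& s : scores) {
        s /= sum;
    }
}

int main()
{
    /* Example: dequantised output scores for a few keywords. */
    std::vector<float> scores {2.1f, 0.3f, -1.7f, 0.9f};
    SoftmaxSketch(scores);
    for (float p : scores) {
        std::printf("%.4f\n", p);
    }
    return 0;
}

Subtracting the maximum before exponentiation keeps std::exp in range; the final division is what turns the network's raw outputs into the probability values the commit message refers to.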