summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/common/ClassifierTests.cc4
-rw-r--r--tests/use_case/kws/InferenceTestMicroNetKws.cc (renamed from tests/use_case/kws/InferenceTestDSCNN.cc)10
-rw-r--r--tests/use_case/kws/KWSHandlerTest.cc30
-rw-r--r--tests/use_case/kws/MfccTests.cc8
-rw-r--r--tests/use_case/kws_asr/InferenceTestMicroNetKws.cc (renamed from tests/use_case/kws_asr/InferenceTestDSCNN.cc)10
-rw-r--r--tests/use_case/kws_asr/InitModels.cc6
-rw-r--r--tests/use_case/kws_asr/MfccTests.cc22
7 files changed, 45 insertions, 45 deletions
diff --git a/tests/common/ClassifierTests.cc b/tests/common/ClassifierTests.cc
index d950304..693f744 100644
--- a/tests/common/ClassifierTests.cc
+++ b/tests/common/ClassifierTests.cc
@@ -35,7 +35,7 @@ void test_classifier_result(std::vector<std::pair<uint32_t, T>>& selectedResults
}
arm::app::Classifier classifier;
- REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5));
+ REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5, true));
REQUIRE(5 == resultVec.size());
for (size_t i = 0; i < resultVec.size(); ++i) {
@@ -50,7 +50,7 @@ TEST_CASE("Common classifier")
TfLiteTensor* outputTens = nullptr;
std::vector <arm::app::ClassificationResult> resultVec;
arm::app::Classifier classifier;
- REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5));
+ REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5, true));
}
SECTION("Test classification results")
diff --git a/tests/use_case/kws/InferenceTestDSCNN.cc b/tests/use_case/kws/InferenceTestMicroNetKws.cc
index 8918073..e6e7753 100644
--- a/tests/use_case/kws/InferenceTestDSCNN.cc
+++ b/tests/use_case/kws/InferenceTestMicroNetKws.cc
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
#include "TestData_kws.hpp"
#include "TensorFlowLiteMicro.hpp"
@@ -74,9 +74,9 @@ void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::
}
}
-TEST_CASE("Running random inference with TensorFlow Lite Micro and DsCnnModel Int8", "[DS_CNN]")
+TEST_CASE("Running random inference with TensorFlow Lite Micro and MicroNetKwsModel Int8", "[MicroNetKws]")
{
- arm::app::DsCnnModel model{};
+ arm::app::MicroNetKwsModel model{};
REQUIRE_FALSE(model.IsInited());
REQUIRE(model.Init());
@@ -85,7 +85,7 @@ TEST_CASE("Running random inference with TensorFlow Lite Micro and DsCnnModel In
REQUIRE(RunInferenceRandom(model));
}
-TEST_CASE("Running inference with TensorFlow Lite Micro and DsCnnModel Uint8", "[DS_CNN]")
+TEST_CASE("Running inference with TensorFlow Lite Micro and MicroNetKwsModel int8", "[MicroNetKws]")
{
REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES);
for (uint32_t i = 0 ; i < NUMBER_OF_IFM_FILES; ++i) {
@@ -94,7 +94,7 @@ TEST_CASE("Running inference with TensorFlow Lite Micro and DsCnnModel Uint8", "
DYNAMIC_SECTION("Executing inference with re-init")
{
- arm::app::DsCnnModel model{};
+ arm::app::MicroNetKwsModel model{};
REQUIRE_FALSE(model.IsInited());
REQUIRE(model.Init());
diff --git a/tests/use_case/kws/KWSHandlerTest.cc b/tests/use_case/kws/KWSHandlerTest.cc
index 50e5a83..a7e75fb 100644
--- a/tests/use_case/kws/KWSHandlerTest.cc
+++ b/tests/use_case/kws/KWSHandlerTest.cc
@@ -15,7 +15,7 @@
* limitations under the License.
*/
#include <catch.hpp>
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
#include "KwsResult.hpp"
@@ -27,7 +27,7 @@
TEST_CASE("Model info")
{
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());
@@ -53,7 +53,7 @@ TEST_CASE("Inference by index")
hal_platform_init(&platform);
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());
@@ -65,8 +65,8 @@ TEST_CASE("Inference by index")
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
- caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */
- caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */
+ caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNetKws. */
+ caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNetKws. */
caseContext.Set<float>("scoreThreshold", 0.5); /* Normalised score threshold. */
arm::app::Classifier classifier; /* classifier wrapper object. */
@@ -97,25 +97,25 @@ TEST_CASE("Inference by index")
SECTION("Index = 0, short clip down")
{
/* Result: down. */
- checker(0, {5});
+ checker(0, {0});
}
SECTION("Index = 1, long clip right->left->up")
{
/* Result: right->right->left->up->up. */
- checker(1, {7, 1, 6, 4, 4});
+ checker(1, {6, 6, 2, 8, 8});
}
SECTION("Index = 2, short clip yes")
{
/* Result: yes. */
- checker(2, {2});
+ checker(2, {9});
}
SECTION("Index = 3, long clip yes->no->go->stop")
{
- /* Result: yes->go->no->go->go->go->stop. */
- checker(3, {2, 11, 3, 11, 11, 11, 10});
+ /* Result: yes->no->no->go->go->stop->stop. */
+ checker(3, {9, 3, 3, 1, 1, 7, 7});
}
}
@@ -132,7 +132,7 @@ TEST_CASE("Inference run all clips")
hal_platform_init(&platform);
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());
@@ -145,9 +145,9 @@ TEST_CASE("Inference run all clips")
caseContext.Set<hal_platform&>("platform", platform);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
- caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */
- caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */
- caseContext.Set<float>("scoreThreshold", 0.9); /* Normalised score threshold. */
+ caseContext.Set<int>("frameLength", g_FrameLength); /* 640 sample length for MicroNet. */
+ caseContext.Set<int>("frameStride", g_FrameStride); /* 320 sample stride for MicroNet. */
+ caseContext.Set<float>("scoreThreshold", 0.7); /* Normalised score threshold. */
arm::app::Classifier classifier; /* classifier wrapper object. */
caseContext.Set<arm::app::Classifier&>("classifier", classifier);
@@ -170,7 +170,7 @@ TEST_CASE("List all audio clips")
hal_platform_init(&platform);
/* Model wrapper object. */
- arm::app::DsCnnModel model;
+ arm::app::MicroNetKwsModel model;
/* Load the model. */
REQUIRE(model.Init());
diff --git a/tests/use_case/kws/MfccTests.cc b/tests/use_case/kws/MfccTests.cc
index 407861f..1d30ef4 100644
--- a/tests/use_case/kws/MfccTests.cc
+++ b/tests/use_case/kws/MfccTests.cc
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsMfcc.hpp"
#include <algorithm>
#include <catch.hpp>
@@ -93,13 +93,13 @@ const std::vector<float> testWavMfcc {
-22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
};
-arm::app::audio::DsCnnMFCC GetMFCCInstance() {
- const int sampFreq = arm::app::audio::DsCnnMFCC::ms_defaultSamplingFreq;
+arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() {
+ const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
const int frameLenMs = 40;
const int frameLenSamples = sampFreq * frameLenMs * 0.001;
const int numMfccFeats = 10;
- return arm::app::audio::DsCnnMFCC(numMfccFeats, frameLenSamples);
+ return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples);
}
template <class T>
diff --git a/tests/use_case/kws_asr/InferenceTestDSCNN.cc b/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc
index ad1731b..fd379b6 100644
--- a/tests/use_case/kws_asr/InferenceTestDSCNN.cc
+++ b/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
#include "TestData_kws.hpp"
#include "TensorFlowLiteMicro.hpp"
@@ -72,8 +72,8 @@ void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::
}
}
-TEST_CASE("Running random inference with Tflu and DsCnnModel Int8", "[DS_CNN]") {
- arm::app::DsCnnModel model{};
+TEST_CASE("Running random inference with Tflu and MicroNetKwsModel Int8", "[MicroNetKws]") {
+ arm::app::MicroNetKwsModel model{};
REQUIRE_FALSE(model.IsInited());
REQUIRE(model.Init());
@@ -82,14 +82,14 @@ TEST_CASE("Running random inference with Tflu and DsCnnModel Int8", "[DS_CNN]")
REQUIRE(RunInferenceRandom(model));
}
-TEST_CASE("Running inference with Tflu and DsCnnModel Uint8", "[DS_CNN]") {
+TEST_CASE("Running inference with Tflu and MicroNetKwsModel Int8", "[MicroNetKws]") {
REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES);
for (uint32_t i = 0; i < NUMBER_OF_IFM_FILES; ++i) {
const int8_t* input_goldenFV = get_ifm_data_array(i);
const int8_t* output_goldenFV = get_ofm_data_array(i);
DYNAMIC_SECTION("Executing inference with re-init") {
- arm::app::DsCnnModel model{};
+ arm::app::MicroNetKwsModel model{};
REQUIRE_FALSE(model.IsInited());
REQUIRE(model.Init());
diff --git a/tests/use_case/kws_asr/InitModels.cc b/tests/use_case/kws_asr/InitModels.cc
index 770944d..97aa092 100644
--- a/tests/use_case/kws_asr/InitModels.cc
+++ b/tests/use_case/kws_asr/InitModels.cc
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "Wav2LetterModel.hpp"
#include <catch.hpp>
@@ -22,8 +22,8 @@
/* Skip this test, Wav2LetterModel if not Vela optimized but only from ML-zoo will fail. */
TEST_CASE("Init two Models", "[.]")
{
- arm::app::DsCnnModel model1;
- arm::app::DsCnnModel model2;
+ arm::app::MicroNetKwsModel model1;
+ arm::app::MicroNetKwsModel model2;
/* Ideally we should load the wav2letter model here, but there is
* none available to run on native (ops not supported on unoptimised
diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc
index 9509519..c0fb723 100644
--- a/tests/use_case/kws_asr/MfccTests.cc
+++ b/tests/use_case/kws_asr/MfccTests.cc
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsMfcc.hpp"
#include <algorithm>
#include <catch.hpp>
@@ -93,17 +93,17 @@ const std::vector<float> testWavMfcc {
-22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
};
-arm::app::audio::DsCnnMFCC GetMFCCInstance() {
- const int sampFreq = arm::app::audio::DsCnnMFCC::ms_defaultSamplingFreq;
+arm::app::audio::MicroNetMFCC GetMFCCInstance() {
+ const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq;
const int frameLenMs = 40;
const int frameLenSamples = sampFreq * frameLenMs * 0.001;
const int numMfccFeats = 10;
- return arm::app::audio::DsCnnMFCC(numMfccFeats, frameLenSamples);
+ return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples);
}
template <class T>
-void TestQuntisedMFCC() {
+void TestQuantisedMFCC() {
const float quantScale = 1.1088106632232666;
const int quantOffset = 95;
std::vector<T> mfccOutput = GetMFCCInstance().MfccComputeQuant<T>(testWav, quantScale, quantOffset);
@@ -118,9 +118,9 @@ void TestQuntisedMFCC() {
REQUIRE(quantizedTestWavMfcc == Approx(mfccOutput[i]).margin(0));
}
}
-template void TestQuntisedMFCC<int8_t>();
-template void TestQuntisedMFCC<uint8_t>();
-template void TestQuntisedMFCC<int16_t>();
+template void TestQuantisedMFCC<int8_t>();
+template void TestQuantisedMFCC<uint8_t>();
+template void TestQuantisedMFCC<int16_t>();
TEST_CASE("MFCC calculation test")
{
@@ -141,16 +141,16 @@ TEST_CASE("MFCC calculation test")
SECTION("int8_t")
{
- TestQuntisedMFCC<int8_t>();
+ TestQuantisedMFCC<int8_t>();
}
SECTION("uint8_t")
{
- TestQuntisedMFCC<uint8_t>();
+ TestQuantisedMFCC<uint8_t>();
}
SECTION("MFCC quant calculation test - int16_t")
{
- TestQuntisedMFCC<int16_t>();
+ TestQuantisedMFCC<int16_t>();
}
} \ No newline at end of file