summaryrefslogtreecommitdiff
path: root/source/use_case/kws/src
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/kws/src')
-rw-r--r--source/use_case/kws/src/MainLoop.cc4
-rw-r--r--source/use_case/kws/src/MicroNetKwsModel.cc (renamed from source/use_case/kws/src/DsCnnModel.cc)11
-rw-r--r--source/use_case/kws/src/UseCaseHandler.cc24
3 files changed, 18 insertions, 21 deletions
diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc
index c683e71..bde246b 100644
--- a/source/use_case/kws/src/MainLoop.cc
+++ b/source/use_case/kws/src/MainLoop.cc
@@ -16,7 +16,7 @@
*/
#include "InputFiles.hpp" /* For input audio clips. */
#include "Classifier.hpp" /* Classifier. */
-#include "DsCnnModel.hpp" /* Model class for running inference. */
+#include "MicroNetKwsModel.hpp" /* Model class for running inference. */
#include "hal.h" /* Brings in platform definitions. */
#include "Labels.hpp" /* For label strings. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
@@ -49,7 +49,7 @@ static void DisplayMenu()
void main_loop(hal_platform& platform)
{
- arm::app::DsCnnModel model; /* Model wrapper object. */
+ arm::app::MicroNetKwsModel model; /* Model wrapper object. */
/* Load the model. */
if (!model.Init()) {
diff --git a/source/use_case/kws/src/DsCnnModel.cc b/source/use_case/kws/src/MicroNetKwsModel.cc
index 4edfc04..48a9b8c 100644
--- a/source/use_case/kws/src/DsCnnModel.cc
+++ b/source/use_case/kws/src/MicroNetKwsModel.cc
@@ -14,16 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
-const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver()
+const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver()
{
return this->m_opResolver;
}
-bool arm::app::DsCnnModel::EnlistOperations()
+bool arm::app::MicroNetKwsModel::EnlistOperations()
{
this->m_opResolver.AddReshape();
this->m_opResolver.AddAveragePool2D();
@@ -31,7 +31,6 @@ bool arm::app::DsCnnModel::EnlistOperations()
this->m_opResolver.AddDepthwiseConv2D();
this->m_opResolver.AddFullyConnected();
this->m_opResolver.AddRelu();
- this->m_opResolver.AddSoftmax();
#if defined(ARM_NPU)
if (kTfLiteOk == this->m_opResolver.AddEthosU()) {
@@ -46,13 +45,13 @@ bool arm::app::DsCnnModel::EnlistOperations()
}
extern uint8_t* GetModelPointer();
-const uint8_t* arm::app::DsCnnModel::ModelPointer()
+const uint8_t* arm::app::MicroNetKwsModel::ModelPointer()
{
return GetModelPointer();
}
extern size_t GetModelLen();
-size_t arm::app::DsCnnModel::ModelSize()
+size_t arm::app::MicroNetKwsModel::ModelSize()
{
return GetModelLen();
} \ No newline at end of file
diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc
index 3d95753..8085af7 100644
--- a/source/use_case/kws/src/UseCaseHandler.cc
+++ b/source/use_case/kws/src/UseCaseHandler.cc
@@ -18,9 +18,9 @@
#include "InputFiles.hpp"
#include "Classifier.hpp"
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsMfcc.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "KwsResult.hpp"
@@ -59,7 +59,7 @@ namespace app {
* @return Function to be called providing audio sample and sliding window index.
*/
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
+ GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
TfLiteTensor* inputTensor,
size_t cacheSize);
@@ -72,8 +72,8 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartX = 20;
constexpr uint32_t dataPsnTxtInfStartY = 40;
constexpr int minTensorDims = static_cast<int>(
- (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
- arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);
+ (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
+ arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
auto& model = ctx.Get<Model&>("model");
@@ -105,10 +105,10 @@ namespace app {
}
TfLiteIntArray* inputShape = model.GetInputShape(0);
- const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx];
- const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx];
+ const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx];
+ const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];
- audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength);
+ audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength);
mfcc.Init();
/* Deduce the data length required for 1 inference from the network parameters. */
@@ -132,7 +132,7 @@ namespace app {
/* We expect to be sampling 1 second worth of data at a time.
* NOTE: This is only used for time stamp calculation. */
- const float secondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq;
+ const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
do {
platform.data_psn->clear(COLOR_BLACK);
@@ -208,7 +208,7 @@ namespace app {
std::vector<ClassificationResult> classificationResult;
auto& classifier = ctx.Get<KwsClassifier&>("classifier");
classifier.GetClassificationResults(outputTensor, classificationResult,
- ctx.Get<std::vector<std::string>&>("labels"), 1);
+ ctx.Get<std::vector<std::string>&>("labels"), 1, true);
results.emplace_back(kws::KwsResult(classificationResult,
audioDataSlider.Index() * secondsPerSample * audioDataStride,
@@ -240,7 +240,6 @@ namespace app {
return true;
}
-
static bool PresentInferenceResult(hal_platform& platform,
const std::vector<arm::app::kws::KwsResult>& results)
{
@@ -259,7 +258,6 @@ namespace app {
std::string topKeyword{"<none>"};
float score = 0.f;
-
if (!results[i].m_resultVec.empty()) {
topKeyword = results[i].m_resultVec[0].m_label;
score = results[i].m_resultVec[0].m_normalisedVal;
@@ -366,7 +364,7 @@ namespace app {
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+ GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
{
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;