diff options
Diffstat (limited to 'source/use_case/kws')
-rw-r--r-- | source/use_case/kws/include/MicroNetKwsMfcc.hpp (renamed from source/use_case/kws/include/DsCnnMfcc.hpp) | 16 | ||||
-rw-r--r-- | source/use_case/kws/include/MicroNetKwsModel.hpp (renamed from source/use_case/kws/include/DsCnnModel.hpp) | 14 | ||||
-rw-r--r-- | source/use_case/kws/src/MainLoop.cc | 4 | ||||
-rw-r--r-- | source/use_case/kws/src/MicroNetKwsModel.cc (renamed from source/use_case/kws/src/DsCnnModel.cc) | 11 | ||||
-rw-r--r-- | source/use_case/kws/src/UseCaseHandler.cc | 24 | ||||
-rw-r--r-- | source/use_case/kws/usecase.cmake | 9 |
6 files changed, 38 insertions, 40 deletions
diff --git a/source/use_case/kws/include/DsCnnMfcc.hpp b/source/use_case/kws/include/MicroNetKwsMfcc.hpp index 3f681af..b2565a3 100644 --- a/source/use_case/kws/include/DsCnnMfcc.hpp +++ b/source/use_case/kws/include/MicroNetKwsMfcc.hpp @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_DSCNN_MFCC_HPP -#define KWS_DSCNN_MFCC_HPP +#ifndef KWS_MICRONET_MFCC_HPP +#define KWS_MICRONET_MFCC_HPP #include "Mfcc.hpp" @@ -23,8 +23,8 @@ namespace arm { namespace app { namespace audio { - /* Class to provide DS-CNN specific MFCC calculation requirements. */ - class DsCnnMFCC : public MFCC { + /* Class to provide MicroNet specific MFCC calculation requirements. */ + class MicroNetKwsMFCC : public MFCC { public: static constexpr uint32_t ms_defaultSamplingFreq = 16000; @@ -33,18 +33,18 @@ namespace audio { static constexpr uint32_t ms_defaultMelHiFreq = 4000; static constexpr bool ms_defaultUseHtkMethod = true; - explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen) + explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen) : MFCC(MfccParams( ms_defaultSamplingFreq, ms_defaultNumFbankBins, ms_defaultMelLoFreq, ms_defaultMelHiFreq, numFeats, frameLen, ms_defaultUseHtkMethod)) {} - DsCnnMFCC() = delete; - ~DsCnnMFCC() = default; + MicroNetKwsMFCC() = delete; + ~MicroNetKwsMFCC() = default; }; } /* namespace audio */ } /* namespace app */ } /* namespace arm */ -#endif /* KWS_DSCNN_MFCC_HPP */
\ No newline at end of file +#endif /* KWS_MICRONET_MFCC_HPP */
\ No newline at end of file diff --git a/source/use_case/kws/include/DsCnnModel.hpp b/source/use_case/kws/include/MicroNetKwsModel.hpp index a1a45cd..3259c45 100644 --- a/source/use_case/kws/include/DsCnnModel.hpp +++ b/source/use_case/kws/include/MicroNetKwsModel.hpp @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_DSCNNMODEL_HPP -#define KWS_DSCNNMODEL_HPP +#ifndef KWS_MICRONETMODEL_HPP +#define KWS_MICRONETMODEL_HPP #include "Model.hpp" @@ -26,11 +26,11 @@ extern const float g_ScoreThreshold; namespace arm { namespace app { - class DsCnnModel : public Model { + class MicroNetKwsModel : public Model { public: /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 2; - static constexpr uint32_t ms_inputColsIdx = 3; + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; static constexpr uint32_t ms_outputRowsIdx = 2; static constexpr uint32_t ms_outputColsIdx = 3; @@ -47,7 +47,7 @@ namespace app { private: /* Maximum number of individual operations that can be enlisted. */ - static constexpr int ms_maxOpCnt = 8; + static constexpr int ms_maxOpCnt = 7; /* A mutable op resolver instance. */ tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver; @@ -56,4 +56,4 @@ namespace app { } /* namespace app */ } /* namespace arm */ -#endif /* KWS_DSCNNMODEL_HPP */ +#endif /* KWS_MICRONETMODEL_HPP */ diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc index c683e71..bde246b 100644 --- a/source/use_case/kws/src/MainLoop.cc +++ b/source/use_case/kws/src/MainLoop.cc @@ -16,7 +16,7 @@ */ #include "InputFiles.hpp" /* For input audio clips. */ #include "Classifier.hpp" /* Classifier. */ -#include "DsCnnModel.hpp" /* Model class for running inference. */ +#include "MicroNetKwsModel.hpp" /* Model class for running inference. */ #include "hal.h" /* Brings in platform definitions. */ #include "Labels.hpp" /* For label strings. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ @@ -49,7 +49,7 @@ static void DisplayMenu() void main_loop(hal_platform& platform) { - arm::app::DsCnnModel model; /* Model wrapper object. */ + arm::app::MicroNetKwsModel model; /* Model wrapper object. */ /* Load the model. */ if (!model.Init()) { diff --git a/source/use_case/kws/src/DsCnnModel.cc b/source/use_case/kws/src/MicroNetKwsModel.cc index 4edfc04..48a9b8c 100644 --- a/source/use_case/kws/src/DsCnnModel.cc +++ b/source/use_case/kws/src/MicroNetKwsModel.cc @@ -14,16 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" -const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver() +const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver() { return this->m_opResolver; } -bool arm::app::DsCnnModel::EnlistOperations() +bool arm::app::MicroNetKwsModel::EnlistOperations() { this->m_opResolver.AddReshape(); this->m_opResolver.AddAveragePool2D(); @@ -31,7 +31,6 @@ bool arm::app::DsCnnModel::EnlistOperations() this->m_opResolver.AddDepthwiseConv2D(); this->m_opResolver.AddFullyConnected(); this->m_opResolver.AddRelu(); - this->m_opResolver.AddSoftmax(); #if defined(ARM_NPU) if (kTfLiteOk == this->m_opResolver.AddEthosU()) { @@ -46,13 +45,13 @@ bool arm::app::DsCnnModel::EnlistOperations() } extern uint8_t* GetModelPointer(); -const uint8_t* arm::app::DsCnnModel::ModelPointer() +const uint8_t* arm::app::MicroNetKwsModel::ModelPointer() { return GetModelPointer(); } extern size_t GetModelLen(); -size_t arm::app::DsCnnModel::ModelSize() +size_t arm::app::MicroNetKwsModel::ModelSize() { return GetModelLen(); }
\ No newline at end of file diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 3d95753..8085af7 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -18,9 +18,9 @@ #include "InputFiles.hpp" #include "Classifier.hpp" -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsMfcc.hpp" #include "AudioUtils.hpp" #include "UseCaseCommonUtils.hpp" #include "KwsResult.hpp" @@ -59,7 +59,7 @@ namespace app { * @return Function to be called providing audio sample and sliding window index. */ static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize); @@ -72,8 +72,8 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast<int>( - (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? - arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? + arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); auto& model = ctx.Get<Model&>("model"); @@ -105,10 +105,10 @@ namespace app { } TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx]; - const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx]; + const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; - audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength); + audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength); mfcc.Init(); /* Deduce the data length required for 1 inference from the network parameters. */ @@ -132,7 +132,7 @@ namespace app { /* We expect to be sampling 1 second worth of data at a time. * NOTE: This is only used for time stamp calculation. */ - const float secondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; do { platform.data_psn->clear(COLOR_BLACK); @@ -208,7 +208,7 @@ namespace app { std::vector<ClassificationResult> classificationResult; auto& classifier = ctx.Get<KwsClassifier&>("classifier"); classifier.GetClassificationResults(outputTensor, classificationResult, - ctx.Get<std::vector<std::string>&>("labels"), 1); + ctx.Get<std::vector<std::string>&>("labels"), 1, true); results.emplace_back(kws::KwsResult(classificationResult, audioDataSlider.Index() * secondsPerSample * audioDataStride, @@ -240,7 +240,6 @@ namespace app { return true; } - static bool PresentInferenceResult(hal_platform& platform, const std::vector<arm::app::kws::KwsResult>& results) { @@ -259,7 +258,6 @@ namespace app { std::string topKeyword{"<none>"}; float score = 0.f; - if (!results[i].m_resultVec.empty()) { topKeyword = results[i].m_resultVec[0].m_label; score = results[i].m_resultVec[0].m_normalisedVal; @@ -366,7 +364,7 @@ namespace app { static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; diff --git a/source/use_case/kws/usecase.cmake b/source/use_case/kws/usecase.cmake index 34e39e4..9f3736e 100644 --- a/source/use_case/kws/usecase.cmake +++ b/source/use_case/kws/usecase.cmake @@ -20,7 +20,7 @@ USER_OPTION(${use_case}_FILE_PATH "Directory with custom WAV input files, or pat PATH_OR_FILE) USER_OPTION(${use_case}_LABELS_TXT_FILE "Labels' txt file for the chosen model." - ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/ds_cnn_labels.txt + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/micronet_kws_labels.txt FILEPATH) USER_OPTION(${use_case}_AUDIO_RATE "Specify the target sampling rate. Default is 16000." @@ -48,7 +48,7 @@ USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples STRING) USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD "Specify the score threshold [0.0, 1.0) that must be applied to the inference results for a label to be deemed valid." - 0.9 + 0.7 STRING) # Generate input files @@ -73,10 +73,11 @@ USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen 0x00100000 STRING) + if (ETHOS_U_NPU_ENABLED) - set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) + set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/kws_micronet_m_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) else() - set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8.tflite) + set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/kws_micronet_m.tflite) endif() set(EXTRA_MODEL_CODE |