diff options
Diffstat (limited to 'source/use_case/kws_asr')
-rw-r--r-- | source/use_case/kws_asr/include/AsrClassifier.hpp | 4 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp (renamed from source/use_case/kws_asr/include/DsCnnMfcc.hpp) | 16 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/MicroNetKwsModel.hpp (renamed from source/use_case/kws_asr/include/DsCnnModel.hpp) | 15 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/AsrClassifier.cc | 3 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/MainLoop.cc | 10 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/MicroNetKwsModel.cc (renamed from source/use_case/kws_asr/src/DsCnnModel.cc) | 13 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/UseCaseHandler.cc | 20 | ||||
-rw-r--r-- | source/use_case/kws_asr/usecase.cmake | 8 |
8 files changed, 44 insertions, 45 deletions
diff --git a/source/use_case/kws_asr/include/AsrClassifier.hpp b/source/use_case/kws_asr/include/AsrClassifier.hpp index 7dbb6e9..6ab9685 100644 --- a/source/use_case/kws_asr/include/AsrClassifier.hpp +++ b/source/use_case/kws_asr/include/AsrClassifier.hpp @@ -32,12 +32,14 @@ namespace app { * populated by this function. * @param[in] labels Labels vector to match classified classes * @param[in] topNCount Number of top classifications to pick. + * @param[in] use_softmax Whether softmax scaling should be applied to model output. * @return true if successful, false otherwise. **/ bool GetClassificationResults( TfLiteTensor* outputTensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount) override; + const std::vector <std::string>& labels, uint32_t topNCount, + bool use_softmax = false) override; private: diff --git a/source/use_case/kws_asr/include/DsCnnMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp index c97dd9d..43bd390 100644 --- a/source/use_case/kws_asr/include/DsCnnMfcc.hpp +++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_DSCNN_MFCC_HPP -#define KWS_ASR_DSCNN_MFCC_HPP +#ifndef KWS_ASR_MICRONET_MFCC_HPP +#define KWS_ASR_MICRONET_MFCC_HPP #include "Mfcc.hpp" @@ -23,8 +23,8 @@ namespace arm { namespace app { namespace audio { - /* Class to provide DS-CNN specific MFCC calculation requirements. */ - class DsCnnMFCC : public MFCC { + /* Class to provide MicroNet specific MFCC calculation requirements. */ + class MicroNetMFCC : public MFCC { public: static constexpr uint32_t ms_defaultSamplingFreq = 16000; @@ -34,18 +34,18 @@ namespace audio { static constexpr bool ms_defaultUseHtkMethod = true; - explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen) + explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen) : MFCC(MfccParams( ms_defaultSamplingFreq, ms_defaultNumFbankBins, ms_defaultMelLoFreq, ms_defaultMelHiFreq, numFeats, frameLen, ms_defaultUseHtkMethod)) {} - DsCnnMFCC() = delete; - ~DsCnnMFCC() = default; + MicroNetMFCC() = delete; + ~MicroNetMFCC() = default; }; } /* namespace audio */ } /* namespace app */ } /* namespace arm */ -#endif /* KWS_ASR_DSCNN_MFCC_HPP */ +#endif /* KWS_ASR_MICRONET_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/DsCnnModel.hpp b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp index 92d96b9..22cf916 100644 --- a/source/use_case/kws_asr/include/DsCnnModel.hpp +++ b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_DSCNNMODEL_HPP -#define KWS_ASR_DSCNNMODEL_HPP +#ifndef KWS_ASR_MICRONETMODEL_HPP +#define KWS_ASR_MICRONETMODEL_HPP #include "Model.hpp" @@ -33,12 +33,11 @@ namespace kws { namespace arm { namespace app { - - class DsCnnModel : public Model { + class MicroNetKwsModel : public Model { public: /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 2; - static constexpr uint32_t ms_inputColsIdx = 3; + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; static constexpr uint32_t ms_outputRowsIdx = 2; static constexpr uint32_t ms_outputColsIdx = 3; @@ -55,7 +54,7 @@ namespace app { private: /* Maximum number of individual operations that can be enlisted. */ - static constexpr int ms_maxOpCnt = 10; + static constexpr int ms_maxOpCnt = 7; /* A mutable op resolver instance. */ tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver; @@ -64,4 +63,4 @@ namespace app { } /* namespace app */ } /* namespace arm */ -#endif /* KWS_DSCNNMODEL_HPP */ +#endif /* KWS_ASR_MICRONETMODEL_HPP */ diff --git a/source/use_case/kws_asr/src/AsrClassifier.cc b/source/use_case/kws_asr/src/AsrClassifier.cc index 57d5058..3f9cd7b 100644 --- a/source/use_case/kws_asr/src/AsrClassifier.cc +++ b/source/use_case/kws_asr/src/AsrClassifier.cc @@ -73,8 +73,9 @@ template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tenso bool arm::app::AsrClassifier::GetClassificationResults( TfLiteTensor* outputTensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount) + const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax) { + UNUSED(use_softmax); vecResults.clear(); constexpr int minTensorDims = static_cast<int>( diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc index d5a2c2b..30cb084 100644 --- a/source/use_case/kws_asr/src/MainLoop.cc +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -16,11 +16,11 @@ */ #include "hal.h" /* Brings in platform definitions. */ #include "InputFiles.hpp" /* For input images. */ -#include "Labels_dscnn.hpp" /* For DS-CNN label strings. */ +#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */ #include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ #include "Classifier.hpp" /* KWS classifier. */ #include "AsrClassifier.hpp" /* ASR classifier. */ -#include "DsCnnModel.hpp" /* KWS model class for running inference. */ +#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */ #include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ @@ -69,7 +69,7 @@ static uint32_t GetOutputInnerLen(const arm::app::Model& model, void main_loop(hal_platform& platform) { /* Model wrapper objects. */ - arm::app::DsCnnModel kwsModel; + arm::app::MicroNetKwsModel kwsModel; arm::app::Wav2LetterModel asrModel; /* Load the models. */ @@ -81,7 +81,7 @@ void main_loop(hal_platform& platform) /* Initialise the asr model using the same allocator from KWS * to re-use the tensor arena. */ if (!asrModel.Init(kwsModel.GetAllocator())) { - printf_err("Failed to initalise ASR model\n"); + printf_err("Failed to initialise ASR model\n"); return; } @@ -137,7 +137,7 @@ void main_loop(hal_platform& platform) caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels); /* Index of the kws outputs we trigger ASR on. */ - caseContext.Set<uint32_t>("keywordindex", 2); + caseContext.Set<uint32_t>("keywordindex", 9 ); /* Loop. */ bool executionSuccessful = true; diff --git a/source/use_case/kws_asr/src/DsCnnModel.cc b/source/use_case/kws_asr/src/MicroNetKwsModel.cc index 71d4ceb..4b44580 100644 --- a/source/use_case/kws_asr/src/DsCnnModel.cc +++ b/source/use_case/kws_asr/src/MicroNetKwsModel.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" @@ -27,21 +27,18 @@ namespace kws { } /* namespace app */ } /* namespace arm */ -const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver() +const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver() { return this->m_opResolver; } -bool arm::app::DsCnnModel::EnlistOperations() +bool arm::app::MicroNetKwsModel::EnlistOperations() { this->m_opResolver.AddAveragePool2D(); this->m_opResolver.AddConv2D(); this->m_opResolver.AddDepthwiseConv2D(); this->m_opResolver.AddFullyConnected(); this->m_opResolver.AddRelu(); - this->m_opResolver.AddSoftmax(); - this->m_opResolver.AddQuantize(); - this->m_opResolver.AddDequantize(); this->m_opResolver.AddReshape(); #if defined(ARM_NPU) @@ -56,12 +53,12 @@ bool arm::app::DsCnnModel::EnlistOperations() return true; } -const uint8_t* arm::app::DsCnnModel::ModelPointer() +const uint8_t* arm::app::MicroNetKwsModel::ModelPointer() { return arm::app::kws::GetModelPointer(); } -size_t arm::app::DsCnnModel::ModelSize() +size_t arm::app::MicroNetKwsModel::ModelSize() { return arm::app::kws::GetModelLen(); }
\ No newline at end of file diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc index 1d88ba1..c67be22 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -20,8 +20,8 @@ #include "InputFiles.hpp" #include "AudioUtils.hpp" #include "UseCaseCommonUtils.hpp" -#include "DsCnnModel.hpp" -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsModel.hpp" +#include "MicroNetKwsMfcc.hpp" #include "Classifier.hpp" #include "KwsResult.hpp" #include "Wav2LetterMfcc.hpp" @@ -77,12 +77,12 @@ namespace app { * * @param[in] mfcc MFCC feature calculator. * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors). + * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). * * @return function function to be called providing audio sample and sliding window index. **/ static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize); @@ -98,8 +98,8 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast<int>( - (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? - arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? + arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); KWSOutput output; @@ -128,7 +128,7 @@ namespace app { const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc"); const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins"); - audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength); + audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength); kwsMfcc.Init(); /* Deduce the data length required for 1 KWS inference from the network parameters. */ @@ -152,7 +152,7 @@ namespace app { /* We expect to be sampling 1 second worth of data at a time * NOTE: This is only used for time stamp calculation. */ - const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq; auto currentIndex = ctx.Get<uint32_t>("clipIndex"); @@ -230,7 +230,7 @@ namespace app { kwsClassifier.GetClassificationResults( kwsOutputTensor, kwsClassificationResult, - ctx.Get<std::vector<std::string>&>("kwslabels"), 1); + ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true); kwsResults.emplace_back( kws::KwsResult( @@ -604,7 +604,7 @@ namespace app { static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake index d8629b6..b3fe020 100644 --- a/source/use_case/kws_asr/usecase.cmake +++ b/source/use_case/kws_asr/usecase.cmake @@ -45,7 +45,7 @@ USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples # Generate kws labels file: USER_OPTION(${use_case}_LABELS_TXT_FILE_KWS "Labels' txt file for the chosen model." - ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/ds_cnn_labels.txt + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/micronet_kws_labels.txt FILEPATH) # Generate asr labels file: @@ -67,10 +67,10 @@ USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [ STRING) if (ETHOS_U_NPU_ENABLED) - set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) + set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/kws_micronet_m_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) set(DEFAULT_MODEL_PATH_ASR ${DEFAULT_MODEL_DIR}/wav2letter_pruned_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) else() - set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8.tflite) + set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/kws_micronet_m.tflite) set(DEFAULT_MODEL_PATH_ASR ${DEFAULT_MODEL_DIR}/wav2letter_pruned_int8.tflite) endif() @@ -134,7 +134,7 @@ generate_labels_code( INPUT "${${use_case}_LABELS_TXT_FILE_KWS}" DESTINATION_SRC ${SRC_GEN_DIR} DESTINATION_HDR ${INC_GEN_DIR} - OUTPUT_FILENAME "Labels_dscnn" + OUTPUT_FILENAME "Labels_micronetkws" NAMESPACE "arm" "app" "kws" ) |