diff options
Diffstat (limited to 'source/use_case/kws/src')
-rw-r--r-- | source/use_case/kws/src/MainLoop.cc | 4 | ||||
-rw-r--r-- | source/use_case/kws/src/MicroNetKwsModel.cc (renamed from source/use_case/kws/src/DsCnnModel.cc) | 11 | ||||
-rw-r--r-- | source/use_case/kws/src/UseCaseHandler.cc | 24 |
3 files changed, 18 insertions, 21 deletions
diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc index c683e71..bde246b 100644 --- a/source/use_case/kws/src/MainLoop.cc +++ b/source/use_case/kws/src/MainLoop.cc @@ -16,7 +16,7 @@ */ #include "InputFiles.hpp" /* For input audio clips. */ #include "Classifier.hpp" /* Classifier. */ -#include "DsCnnModel.hpp" /* Model class for running inference. */ +#include "MicroNetKwsModel.hpp" /* Model class for running inference. */ #include "hal.h" /* Brings in platform definitions. */ #include "Labels.hpp" /* For label strings. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ @@ -49,7 +49,7 @@ static void DisplayMenu() void main_loop(hal_platform& platform) { - arm::app::DsCnnModel model; /* Model wrapper object. */ + arm::app::MicroNetKwsModel model; /* Model wrapper object. */ /* Load the model. */ if (!model.Init()) { diff --git a/source/use_case/kws/src/DsCnnModel.cc b/source/use_case/kws/src/MicroNetKwsModel.cc index 4edfc04..48a9b8c 100644 --- a/source/use_case/kws/src/DsCnnModel.cc +++ b/source/use_case/kws/src/MicroNetKwsModel.cc @@ -14,16 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" -const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver() +const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver() { return this->m_opResolver; } -bool arm::app::DsCnnModel::EnlistOperations() +bool arm::app::MicroNetKwsModel::EnlistOperations() { this->m_opResolver.AddReshape(); this->m_opResolver.AddAveragePool2D(); @@ -31,7 +31,6 @@ bool arm::app::DsCnnModel::EnlistOperations() this->m_opResolver.AddDepthwiseConv2D(); this->m_opResolver.AddFullyConnected(); this->m_opResolver.AddRelu(); - this->m_opResolver.AddSoftmax(); #if defined(ARM_NPU) if (kTfLiteOk == this->m_opResolver.AddEthosU()) { @@ -46,13 +45,13 @@ bool arm::app::DsCnnModel::EnlistOperations() } extern uint8_t* GetModelPointer(); -const uint8_t* arm::app::DsCnnModel::ModelPointer() +const uint8_t* arm::app::MicroNetKwsModel::ModelPointer() { return GetModelPointer(); } extern size_t GetModelLen(); -size_t arm::app::DsCnnModel::ModelSize() +size_t arm::app::MicroNetKwsModel::ModelSize() { return GetModelLen(); }
\ No newline at end of file diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 3d95753..8085af7 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -18,9 +18,9 @@ #include "InputFiles.hpp" #include "Classifier.hpp" -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsMfcc.hpp" #include "AudioUtils.hpp" #include "UseCaseCommonUtils.hpp" #include "KwsResult.hpp" @@ -59,7 +59,7 @@ namespace app { * @return Function to be called providing audio sample and sliding window index. */ static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize); @@ -72,8 +72,8 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast<int>( - (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? - arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? + arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); auto& model = ctx.Get<Model&>("model"); @@ -105,10 +105,10 @@ namespace app { } TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx]; - const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx]; + const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; - audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength); + audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength); mfcc.Init(); /* Deduce the data length required for 1 inference from the network parameters. */ @@ -132,7 +132,7 @@ namespace app { /* We expect to be sampling 1 second worth of data at a time. * NOTE: This is only used for time stamp calculation. */ - const float secondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; do { platform.data_psn->clear(COLOR_BLACK); @@ -208,7 +208,7 @@ namespace app { std::vector<ClassificationResult> classificationResult; auto& classifier = ctx.Get<KwsClassifier&>("classifier"); classifier.GetClassificationResults(outputTensor, classificationResult, - ctx.Get<std::vector<std::string>&>("labels"), 1); + ctx.Get<std::vector<std::string>&>("labels"), 1, true); results.emplace_back(kws::KwsResult(classificationResult, audioDataSlider.Index() * secondsPerSample * audioDataStride, @@ -240,7 +240,6 @@ namespace app { return true; } - static bool PresentInferenceResult(hal_platform& platform, const std::vector<arm::app::kws::KwsResult>& results) { @@ -259,7 +258,6 @@ namespace app { std::string topKeyword{"<none>"}; float score = 0.f; - if (!results[i].m_resultVec.empty()) { topKeyword = results[i].m_resultVec[0].m_label; score = results[i].m_resultVec[0].m_normalisedVal; @@ -366,7 +364,7 @@ namespace app { static std::function<void (std::vector<int16_t>&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc; |