summaryrefslogtreecommitdiff
path: root/source/use_case/kws_asr
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/kws_asr')
-rw-r--r--source/use_case/kws_asr/include/AsrClassifier.hpp4
-rw-r--r--source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp (renamed from source/use_case/kws_asr/include/DsCnnMfcc.hpp)16
-rw-r--r--source/use_case/kws_asr/include/MicroNetKwsModel.hpp (renamed from source/use_case/kws_asr/include/DsCnnModel.hpp)15
-rw-r--r--source/use_case/kws_asr/src/AsrClassifier.cc3
-rw-r--r--source/use_case/kws_asr/src/MainLoop.cc10
-rw-r--r--source/use_case/kws_asr/src/MicroNetKwsModel.cc (renamed from source/use_case/kws_asr/src/DsCnnModel.cc)13
-rw-r--r--source/use_case/kws_asr/src/UseCaseHandler.cc20
-rw-r--r--source/use_case/kws_asr/usecase.cmake8
8 files changed, 44 insertions, 45 deletions
diff --git a/source/use_case/kws_asr/include/AsrClassifier.hpp b/source/use_case/kws_asr/include/AsrClassifier.hpp
index 7dbb6e9..6ab9685 100644
--- a/source/use_case/kws_asr/include/AsrClassifier.hpp
+++ b/source/use_case/kws_asr/include/AsrClassifier.hpp
@@ -32,12 +32,14 @@ namespace app {
* populated by this function.
* @param[in] labels Labels vector to match classified classes
* @param[in] topNCount Number of top classifications to pick.
+ * @param[in] use_softmax Whether softmax scaling should be applied to model output.
* @return true if successful, false otherwise.
**/
bool GetClassificationResults(
TfLiteTensor* outputTensor,
std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, uint32_t topNCount) override;
+ const std::vector <std::string>& labels, uint32_t topNCount,
+ bool use_softmax = false) override;
private:
diff --git a/source/use_case/kws_asr/include/DsCnnMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
index c97dd9d..43bd390 100644
--- a/source/use_case/kws_asr/include/DsCnnMfcc.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
@@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_DSCNN_MFCC_HPP
-#define KWS_ASR_DSCNN_MFCC_HPP
+#ifndef KWS_ASR_MICRONET_MFCC_HPP
+#define KWS_ASR_MICRONET_MFCC_HPP
#include "Mfcc.hpp"
@@ -23,8 +23,8 @@ namespace arm {
namespace app {
namespace audio {
- /* Class to provide DS-CNN specific MFCC calculation requirements. */
- class DsCnnMFCC : public MFCC {
+ /* Class to provide MicroNet specific MFCC calculation requirements. */
+ class MicroNetMFCC : public MFCC {
public:
static constexpr uint32_t ms_defaultSamplingFreq = 16000;
@@ -34,18 +34,18 @@ namespace audio {
static constexpr bool ms_defaultUseHtkMethod = true;
- explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen)
+ explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen)
: MFCC(MfccParams(
ms_defaultSamplingFreq, ms_defaultNumFbankBins,
ms_defaultMelLoFreq, ms_defaultMelHiFreq,
numFeats, frameLen, ms_defaultUseHtkMethod))
{}
- DsCnnMFCC() = delete;
- ~DsCnnMFCC() = default;
+ MicroNetMFCC() = delete;
+ ~MicroNetMFCC() = default;
};
} /* namespace audio */
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_ASR_DSCNN_MFCC_HPP */
+#endif /* KWS_ASR_MICRONET_MFCC_HPP */
diff --git a/source/use_case/kws_asr/include/DsCnnModel.hpp b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp
index 92d96b9..22cf916 100644
--- a/source/use_case/kws_asr/include/DsCnnModel.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp
@@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_DSCNNMODEL_HPP
-#define KWS_ASR_DSCNNMODEL_HPP
+#ifndef KWS_ASR_MICRONETMODEL_HPP
+#define KWS_ASR_MICRONETMODEL_HPP
#include "Model.hpp"
@@ -33,12 +33,11 @@ namespace kws {
namespace arm {
namespace app {
-
- class DsCnnModel : public Model {
+ class MicroNetKwsModel : public Model {
public:
/* Indices for the expected model - based on input and output tensor shapes */
- static constexpr uint32_t ms_inputRowsIdx = 2;
- static constexpr uint32_t ms_inputColsIdx = 3;
+ static constexpr uint32_t ms_inputRowsIdx = 1;
+ static constexpr uint32_t ms_inputColsIdx = 2;
static constexpr uint32_t ms_outputRowsIdx = 2;
static constexpr uint32_t ms_outputColsIdx = 3;
@@ -55,7 +54,7 @@ namespace app {
private:
/* Maximum number of individual operations that can be enlisted. */
- static constexpr int ms_maxOpCnt = 10;
+ static constexpr int ms_maxOpCnt = 7;
/* A mutable op resolver instance. */
tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver;
@@ -64,4 +63,4 @@ namespace app {
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_DSCNNMODEL_HPP */
+#endif /* KWS_ASR_MICRONETMODEL_HPP */
diff --git a/source/use_case/kws_asr/src/AsrClassifier.cc b/source/use_case/kws_asr/src/AsrClassifier.cc
index 57d5058..3f9cd7b 100644
--- a/source/use_case/kws_asr/src/AsrClassifier.cc
+++ b/source/use_case/kws_asr/src/AsrClassifier.cc
@@ -73,8 +73,9 @@ template bool arm::app::AsrClassifier::GetTopResults<int8_t>(TfLiteTensor* tenso
bool arm::app::AsrClassifier::GetClassificationResults(
TfLiteTensor* outputTensor,
std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, uint32_t topNCount)
+ const std::vector <std::string>& labels, uint32_t topNCount, bool use_softmax)
{
+ UNUSED(use_softmax);
vecResults.clear();
constexpr int minTensorDims = static_cast<int>(
diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc
index d5a2c2b..30cb084 100644
--- a/source/use_case/kws_asr/src/MainLoop.cc
+++ b/source/use_case/kws_asr/src/MainLoop.cc
@@ -16,11 +16,11 @@
*/
#include "hal.h" /* Brings in platform definitions. */
#include "InputFiles.hpp" /* For input images. */
-#include "Labels_dscnn.hpp" /* For DS-CNN label strings. */
+#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */
#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */
#include "Classifier.hpp" /* KWS classifier. */
#include "AsrClassifier.hpp" /* ASR classifier. */
-#include "DsCnnModel.hpp" /* KWS model class for running inference. */
+#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */
#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
@@ -69,7 +69,7 @@ static uint32_t GetOutputInnerLen(const arm::app::Model& model,
void main_loop(hal_platform& platform)
{
/* Model wrapper objects. */
- arm::app::DsCnnModel kwsModel;
+ arm::app::MicroNetKwsModel kwsModel;
arm::app::Wav2LetterModel asrModel;
/* Load the models. */
@@ -81,7 +81,7 @@ void main_loop(hal_platform& platform)
/* Initialise the asr model using the same allocator from KWS
* to re-use the tensor arena. */
if (!asrModel.Init(kwsModel.GetAllocator())) {
- printf_err("Failed to initalise ASR model\n");
+ printf_err("Failed to initialise ASR model\n");
return;
}
@@ -137,7 +137,7 @@ void main_loop(hal_platform& platform)
caseContext.Set<const std::vector <std::string>&>("kwslabels", kwsLabels);
/* Index of the kws outputs we trigger ASR on. */
- caseContext.Set<uint32_t>("keywordindex", 2);
+ caseContext.Set<uint32_t>("keywordindex", 9 );
/* Loop. */
bool executionSuccessful = true;
diff --git a/source/use_case/kws_asr/src/DsCnnModel.cc b/source/use_case/kws_asr/src/MicroNetKwsModel.cc
index 71d4ceb..4b44580 100644
--- a/source/use_case/kws_asr/src/DsCnnModel.cc
+++ b/source/use_case/kws_asr/src/MicroNetKwsModel.cc
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include "DsCnnModel.hpp"
+#include "MicroNetKwsModel.hpp"
#include "hal.h"
@@ -27,21 +27,18 @@ namespace kws {
} /* namespace app */
} /* namespace arm */
-const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver()
+const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver()
{
return this->m_opResolver;
}
-bool arm::app::DsCnnModel::EnlistOperations()
+bool arm::app::MicroNetKwsModel::EnlistOperations()
{
this->m_opResolver.AddAveragePool2D();
this->m_opResolver.AddConv2D();
this->m_opResolver.AddDepthwiseConv2D();
this->m_opResolver.AddFullyConnected();
this->m_opResolver.AddRelu();
- this->m_opResolver.AddSoftmax();
- this->m_opResolver.AddQuantize();
- this->m_opResolver.AddDequantize();
this->m_opResolver.AddReshape();
#if defined(ARM_NPU)
@@ -56,12 +53,12 @@ bool arm::app::DsCnnModel::EnlistOperations()
return true;
}
-const uint8_t* arm::app::DsCnnModel::ModelPointer()
+const uint8_t* arm::app::MicroNetKwsModel::ModelPointer()
{
return arm::app::kws::GetModelPointer();
}
-size_t arm::app::DsCnnModel::ModelSize()
+size_t arm::app::MicroNetKwsModel::ModelSize()
{
return arm::app::kws::GetModelLen();
} \ No newline at end of file
diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc
index 1d88ba1..c67be22 100644
--- a/source/use_case/kws_asr/src/UseCaseHandler.cc
+++ b/source/use_case/kws_asr/src/UseCaseHandler.cc
@@ -20,8 +20,8 @@
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
-#include "DsCnnModel.hpp"
-#include "DsCnnMfcc.hpp"
+#include "MicroNetKwsModel.hpp"
+#include "MicroNetKwsMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
@@ -77,12 +77,12 @@ namespace app {
*
* @param[in] mfcc MFCC feature calculator.
* @param[in,out] inputTensor Input tensor pointer to store calculated features.
- * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors).
+ * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
*
* @return function function to be called providing audio sample and sliding window index.
**/
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
+ GetFeatureCalculator(audio::MicroNetMFCC& mfcc,
TfLiteTensor* inputTensor,
size_t cacheSize);
@@ -98,8 +98,8 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartY = 40;
constexpr int minTensorDims = static_cast<int>(
- (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
- arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);
+ (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
+ arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
KWSOutput output;
@@ -128,7 +128,7 @@ namespace app {
const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
- audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
+ audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
kwsMfcc.Init();
/* Deduce the data length required for 1 KWS inference from the network parameters. */
@@ -152,7 +152,7 @@ namespace app {
/* We expect to be sampling 1 second worth of data at a time
* NOTE: This is only used for time stamp calculation. */
- const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq;
+ const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
auto currentIndex = ctx.Get<uint32_t>("clipIndex");
@@ -230,7 +230,7 @@ namespace app {
kwsClassifier.GetClassificationResults(
kwsOutputTensor, kwsClassificationResult,
- ctx.Get<std::vector<std::string>&>("kwslabels"), 1);
+ ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
kwsResults.emplace_back(
kws::KwsResult(
@@ -604,7 +604,7 @@ namespace app {
static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
+ GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
{
std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake
index d8629b6..b3fe020 100644
--- a/source/use_case/kws_asr/usecase.cmake
+++ b/source/use_case/kws_asr/usecase.cmake
@@ -45,7 +45,7 @@ USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples
# Generate kws labels file:
USER_OPTION(${use_case}_LABELS_TXT_FILE_KWS "Labels' txt file for the chosen model."
- ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/ds_cnn_labels.txt
+ ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/micronet_kws_labels.txt
FILEPATH)
# Generate asr labels file:
@@ -67,10 +67,10 @@ USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [
STRING)
if (ETHOS_U_NPU_ENABLED)
- set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite)
+ set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/kws_micronet_m_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite)
set(DEFAULT_MODEL_PATH_ASR ${DEFAULT_MODEL_DIR}/wav2letter_pruned_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite)
else()
- set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8.tflite)
+ set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/kws_micronet_m.tflite)
set(DEFAULT_MODEL_PATH_ASR ${DEFAULT_MODEL_DIR}/wav2letter_pruned_int8.tflite)
endif()
@@ -134,7 +134,7 @@ generate_labels_code(
INPUT "${${use_case}_LABELS_TXT_FILE_KWS}"
DESTINATION_SRC ${SRC_GEN_DIR}
DESTINATION_HDR ${INC_GEN_DIR}
- OUTPUT_FILENAME "Labels_dscnn"
+ OUTPUT_FILENAME "Labels_micronetkws"
NAMESPACE "arm" "app" "kws"
)