diff options
Diffstat (limited to 'source/use_case/kws_asr/include')
-rw-r--r-- | source/use_case/kws_asr/include/AsrClassifier.hpp | 4 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp (renamed from source/use_case/kws_asr/include/DsCnnMfcc.hpp) | 16 | ||||
-rw-r--r-- | source/use_case/kws_asr/include/MicroNetKwsModel.hpp (renamed from source/use_case/kws_asr/include/DsCnnModel.hpp) | 15 |
3 files changed, 18 insertions, 17 deletions
diff --git a/source/use_case/kws_asr/include/AsrClassifier.hpp b/source/use_case/kws_asr/include/AsrClassifier.hpp index 7dbb6e9..6ab9685 100644 --- a/source/use_case/kws_asr/include/AsrClassifier.hpp +++ b/source/use_case/kws_asr/include/AsrClassifier.hpp @@ -32,12 +32,14 @@ namespace app { * populated by this function. * @param[in] labels Labels vector to match classified classes * @param[in] topNCount Number of top classifications to pick. + * @param[in] use_softmax Whether softmax scaling should be applied to model output. * @return true if successful, false otherwise. **/ bool GetClassificationResults( TfLiteTensor* outputTensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount) override; + const std::vector <std::string>& labels, uint32_t topNCount, + bool use_softmax = false) override; private: diff --git a/source/use_case/kws_asr/include/DsCnnMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp index c97dd9d..43bd390 100644 --- a/source/use_case/kws_asr/include/DsCnnMfcc.hpp +++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_DSCNN_MFCC_HPP -#define KWS_ASR_DSCNN_MFCC_HPP +#ifndef KWS_ASR_MICRONET_MFCC_HPP +#define KWS_ASR_MICRONET_MFCC_HPP #include "Mfcc.hpp" @@ -23,8 +23,8 @@ namespace arm { namespace app { namespace audio { - /* Class to provide DS-CNN specific MFCC calculation requirements. */ - class DsCnnMFCC : public MFCC { + /* Class to provide MicroNet specific MFCC calculation requirements. */ + class MicroNetMFCC : public MFCC { public: static constexpr uint32_t ms_defaultSamplingFreq = 16000; @@ -34,18 +34,18 @@ namespace audio { static constexpr bool ms_defaultUseHtkMethod = true; - explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen) + explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen) : MFCC(MfccParams( ms_defaultSamplingFreq, ms_defaultNumFbankBins, ms_defaultMelLoFreq, ms_defaultMelHiFreq, numFeats, frameLen, ms_defaultUseHtkMethod)) {} - DsCnnMFCC() = delete; - ~DsCnnMFCC() = default; + MicroNetMFCC() = delete; + ~MicroNetMFCC() = default; }; } /* namespace audio */ } /* namespace app */ } /* namespace arm */ -#endif /* KWS_ASR_DSCNN_MFCC_HPP */ +#endif /* KWS_ASR_MICRONET_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/DsCnnModel.hpp b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp index 92d96b9..22cf916 100644 --- a/source/use_case/kws_asr/include/DsCnnModel.hpp +++ b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp @@ -14,8 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef KWS_ASR_DSCNNMODEL_HPP -#define KWS_ASR_DSCNNMODEL_HPP +#ifndef KWS_ASR_MICRONETMODEL_HPP +#define KWS_ASR_MICRONETMODEL_HPP #include "Model.hpp" @@ -33,12 +33,11 @@ namespace kws { namespace arm { namespace app { - - class DsCnnModel : public Model { + class MicroNetKwsModel : public Model { public: /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 2; - static constexpr uint32_t ms_inputColsIdx = 3; + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; static constexpr uint32_t ms_outputRowsIdx = 2; static constexpr uint32_t ms_outputColsIdx = 3; @@ -55,7 +54,7 @@ namespace app { private: /* Maximum number of individual operations that can be enlisted. */ - static constexpr int ms_maxOpCnt = 10; + static constexpr int ms_maxOpCnt = 7; /* A mutable op resolver instance. */ tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver; @@ -64,4 +63,4 @@ namespace app { } /* namespace app */ } /* namespace arm */ -#endif /* KWS_DSCNNMODEL_HPP */ +#endif /* KWS_ASR_MICRONETMODEL_HPP */ |