summaryrefslogtreecommitdiff
path: root/source/use_case/kws_asr/include
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/kws_asr/include')
-rw-r--r--source/use_case/kws_asr/include/AsrClassifier.hpp4
-rw-r--r--source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp (renamed from source/use_case/kws_asr/include/DsCnnMfcc.hpp)16
-rw-r--r--source/use_case/kws_asr/include/MicroNetKwsModel.hpp (renamed from source/use_case/kws_asr/include/DsCnnModel.hpp)15
3 files changed, 18 insertions, 17 deletions
diff --git a/source/use_case/kws_asr/include/AsrClassifier.hpp b/source/use_case/kws_asr/include/AsrClassifier.hpp
index 7dbb6e9..6ab9685 100644
--- a/source/use_case/kws_asr/include/AsrClassifier.hpp
+++ b/source/use_case/kws_asr/include/AsrClassifier.hpp
@@ -32,12 +32,14 @@ namespace app {
* populated by this function.
* @param[in] labels Labels vector to match classified classes
* @param[in] topNCount Number of top classifications to pick.
+ * @param[in] use_softmax Whether softmax scaling should be applied to model output.
* @return true if successful, false otherwise.
**/
bool GetClassificationResults(
TfLiteTensor* outputTensor,
std::vector<ClassificationResult>& vecResults,
- const std::vector <std::string>& labels, uint32_t topNCount) override;
+ const std::vector <std::string>& labels, uint32_t topNCount,
+ bool use_softmax = false) override;
private:
diff --git a/source/use_case/kws_asr/include/DsCnnMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
index c97dd9d..43bd390 100644
--- a/source/use_case/kws_asr/include/DsCnnMfcc.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp
@@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_DSCNN_MFCC_HPP
-#define KWS_ASR_DSCNN_MFCC_HPP
+#ifndef KWS_ASR_MICRONET_MFCC_HPP
+#define KWS_ASR_MICRONET_MFCC_HPP
#include "Mfcc.hpp"
@@ -23,8 +23,8 @@ namespace arm {
namespace app {
namespace audio {
- /* Class to provide DS-CNN specific MFCC calculation requirements. */
- class DsCnnMFCC : public MFCC {
+ /* Class to provide MicroNet specific MFCC calculation requirements. */
+ class MicroNetMFCC : public MFCC {
public:
static constexpr uint32_t ms_defaultSamplingFreq = 16000;
@@ -34,18 +34,18 @@ namespace audio {
static constexpr bool ms_defaultUseHtkMethod = true;
- explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen)
+ explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen)
: MFCC(MfccParams(
ms_defaultSamplingFreq, ms_defaultNumFbankBins,
ms_defaultMelLoFreq, ms_defaultMelHiFreq,
numFeats, frameLen, ms_defaultUseHtkMethod))
{}
- DsCnnMFCC() = delete;
- ~DsCnnMFCC() = default;
+ MicroNetMFCC() = delete;
+ ~MicroNetMFCC() = default;
};
} /* namespace audio */
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_ASR_DSCNN_MFCC_HPP */
+#endif /* KWS_ASR_MICRONET_MFCC_HPP */
diff --git a/source/use_case/kws_asr/include/DsCnnModel.hpp b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp
index 92d96b9..22cf916 100644
--- a/source/use_case/kws_asr/include/DsCnnModel.hpp
+++ b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp
@@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef KWS_ASR_DSCNNMODEL_HPP
-#define KWS_ASR_DSCNNMODEL_HPP
+#ifndef KWS_ASR_MICRONETMODEL_HPP
+#define KWS_ASR_MICRONETMODEL_HPP
#include "Model.hpp"
@@ -33,12 +33,11 @@ namespace kws {
namespace arm {
namespace app {
-
- class DsCnnModel : public Model {
+ class MicroNetKwsModel : public Model {
public:
/* Indices for the expected model - based on input and output tensor shapes */
- static constexpr uint32_t ms_inputRowsIdx = 2;
- static constexpr uint32_t ms_inputColsIdx = 3;
+ static constexpr uint32_t ms_inputRowsIdx = 1;
+ static constexpr uint32_t ms_inputColsIdx = 2;
static constexpr uint32_t ms_outputRowsIdx = 2;
static constexpr uint32_t ms_outputColsIdx = 3;
@@ -55,7 +54,7 @@ namespace app {
private:
/* Maximum number of individual operations that can be enlisted. */
- static constexpr int ms_maxOpCnt = 10;
+ static constexpr int ms_maxOpCnt = 7;
/* A mutable op resolver instance. */
tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver;
@@ -64,4 +63,4 @@ namespace app {
} /* namespace app */
} /* namespace arm */
-#endif /* KWS_DSCNNMODEL_HPP */
+#endif /* KWS_ASR_MICRONETMODEL_HPP */