summaryrefslogtreecommitdiff
path: root/source/use_case/kws/include/KwsProcessing.hpp
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-04-13 11:58:28 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-13 11:58:28 +0100
commite6398cd54642a6a420b14003ad62309448dd724e (patch)
tree8c72ae7fcc7badb4ae161d07ef63eda38f6ff65e /source/use_case/kws/include/KwsProcessing.hpp
parent7e56d8f55c770204deaa2de644990828b9ff083b (diff)
downloadml-embedded-evaluation-kit-e6398cd54642a6a420b14003ad62309448dd724e.tar.gz
MLECO-3075: Add KWS use case API
Removed some of the templates for feature calculation that we are unlikely to ever use. We might be able to refactor the feature caching and feature calculator code in the future to better integrate it with with PreProcess API. Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: Ic0c0c581c71e2553d41ff72cd1ed3b3efa64fa92
Diffstat (limited to 'source/use_case/kws/include/KwsProcessing.hpp')
-rw-r--r--source/use_case/kws/include/KwsProcessing.hpp135
1 files changed, 135 insertions, 0 deletions
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp
new file mode 100644
index 0000000..abf20ab
--- /dev/null
+++ b/source/use_case/kws/include/KwsProcessing.hpp
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef KWS_PROCESSING_HPP
+#define KWS_PROCESSING_HPP
+
+#include <AudioUtils.hpp>
+#include "BaseProcessing.hpp"
+#include "Model.hpp"
+#include "Classifier.hpp"
+#include "MicroNetKwsMfcc.hpp"
+
+#include <functional>
+
+namespace arm {
+namespace app {
+
+ /**
+ * @brief Pre-processing class for Keyword Spotting use case.
+ * Implements methods declared by BasePreProcess and anything else needed
+ * to populate input tensors ready for inference.
+ */
+ class KWSPreProcess : public BasePreProcess {
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] model Pointer to the the KWS Model object.
+ * @param[in] numFeatures How many MFCC features to use.
+ * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
+ * sliding a window through the audio sample.
+ * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
+ **/
+ explicit KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride);
+
+ /**
+ * @brief Should perform pre-processing of 'raw' input audio data and load it into
+ * TFLite Micro input tensors ready for inference.
+ * @param[in] input Pointer to the data that pre-processing will work on.
+ * @param[in] inputSize Size of the input data.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */
+ size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */
+ size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
+
+ private:
+ const int m_mfccFrameLength;
+ const int m_mfccFrameStride;
+
+ audio::MicroNetKwsMFCC m_mfcc;
+ audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
+ size_t m_numMfccVectorsInAudioStride;
+ size_t m_numReusedMfccVectors;
+ std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
+
+ /**
+ * @brief Returns a function to perform feature calculation and populates input tensor data with
+ * MFCC data.
+ *
+ * Input tensor data type check is performed to choose correct MFCC feature data type.
+ * If tensor has an integer data type then original features are quantised.
+ *
+ * Warning: MFCC calculator provided as input must have the same life scope as returned function.
+ *
+ * @param[in] mfcc MFCC feature calculator.
+ * @param[in,out] inputTensor Input tensor pointer to store calculated features.
+ * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
+ * @return Function to be called providing audio sample and sliding window index.
+ */
+ std::function<void (std::vector<int16_t>&, int, bool, size_t)>
+ GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
+ TfLiteTensor* inputTensor,
+ size_t cacheSize);
+
+ template<class T>
+ std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+ FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+ std::function<std::vector<T> (std::vector<int16_t>& )> compute);
+ };
+
+ /**
+ * @brief Post-processing class for Keyword Spotting use case.
+ * Implements methods declared by BasePostProcess and anything else needed
+ * to populate result vector.
+ */
+ class KWSPostProcess : public BasePostProcess {
+
+ private:
+ Classifier& m_kwsClassifier;
+ const std::vector<std::string>& m_labels;
+ std::vector<ClassificationResult>& m_results;
+
+ public:
+ const float m_scoreThreshold;
+ /**
+ * @brief Constructor
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] model Pointer to the the Image classification Model object.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in] results Vector of classification results to store decoded outputs.
+ * @param[in] scoreThreshold Predicted model score must be larger than this value to be accepted.
+ **/
+ KWSPostProcess(Classifier& classifier, Model* model,
+ const std::vector<std::string>& labels,
+ std::vector<ClassificationResult>& results,
+ float scoreThreshold);
+
+ /**
+ * @brief Should perform post-processing of the result of inference then populate
+ * populate KWS result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPostProcess() override;
+ };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* KWS_PROCESSING_HPP */ \ No newline at end of file