summaryrefslogtreecommitdiff
path: root/source/use_case/kws
diff options
context:
space:
mode:
authorKshitij Sisodia <kshitij.sisodia@arm.com>2022-05-06 09:13:03 +0100
committerKshitij Sisodia <kshitij.sisodia@arm.com>2022-05-06 17:11:41 +0100
commitaa4bcb14d0cbee910331545dd2fc086b58c37170 (patch)
treee67a43a43f61c6f8b6aad19018b0827baf7e31a6 /source/use_case/kws
parentfcca863bafd5f33522bc14c23dde4540e264ec94 (diff)
downloadml-embedded-evaluation-kit-aa4bcb14d0cbee910331545dd2fc086b58c37170.tar.gz
MLECO-3183: Refactoring application sources
Platform agnostic application sources are moved into application api module with their own independent CMake projects. Changes for MLECO-3080 also included - they create CMake projects individial API's (again, platform agnostic) that dependent on the common logic. The API for KWS_API "joint" API has been removed and now the use case relies on individual KWS, and ASR API libraries. Change-Id: I1f7748dc767abb3904634a04e0991b74ac7b756d Signed-off-by: Kshitij Sisodia <kshitij.sisodia@arm.com>
Diffstat (limited to 'source/use_case/kws')
-rw-r--r--source/use_case/kws/include/KwsProcessing.hpp138
-rw-r--r--source/use_case/kws/include/KwsResult.hpp63
-rw-r--r--source/use_case/kws/include/MicroNetKwsMfcc.hpp50
-rw-r--r--source/use_case/kws/include/MicroNetKwsModel.hpp59
-rw-r--r--source/use_case/kws/src/KwsProcessing.cc212
-rw-r--r--source/use_case/kws/src/MainLoop.cc34
-rw-r--r--source/use_case/kws/src/MicroNetKwsModel.cc56
-rw-r--r--source/use_case/kws/usecase.cmake3
8 files changed, 31 insertions, 584 deletions
diff --git a/source/use_case/kws/include/KwsProcessing.hpp b/source/use_case/kws/include/KwsProcessing.hpp
deleted file mode 100644
index d3de3b3..0000000
--- a/source/use_case/kws/include/KwsProcessing.hpp
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef KWS_PROCESSING_HPP
-#define KWS_PROCESSING_HPP
-
-#include <AudioUtils.hpp>
-#include "BaseProcessing.hpp"
-#include "Model.hpp"
-#include "Classifier.hpp"
-#include "MicroNetKwsMfcc.hpp"
-
-#include <functional>
-
-namespace arm {
-namespace app {
-
- /**
- * @brief Pre-processing class for Keyword Spotting use case.
- * Implements methods declared by BasePreProcess and anything else needed
- * to populate input tensors ready for inference.
- */
- class KwsPreProcess : public BasePreProcess {
-
- public:
- /**
- * @brief Constructor
- * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
- * @param[in] numFeatures How many MFCC features to use.
- * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
- * for an inference.
- * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
- * sliding a window through the audio sample.
- * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
- **/
- explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
- int mfccFrameLength, int mfccFrameStride);
-
- /**
- * @brief Should perform pre-processing of 'raw' input audio data and load it into
- * TFLite Micro input tensors ready for inference.
- * @param[in] input Pointer to the data that pre-processing will work on.
- * @param[in] inputSize Size of the input data.
- * @return true if successful, false otherwise.
- **/
- bool DoPreProcess(const void* input, size_t inputSize) override;
-
- size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */
- size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */
- size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
-
- private:
- TfLiteTensor* m_inputTensor; /* Model input tensor. */
- const int m_mfccFrameLength;
- const int m_mfccFrameStride;
- const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */
-
- audio::MicroNetKwsMFCC m_mfcc;
- audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
- size_t m_numMfccVectorsInAudioStride;
- size_t m_numReusedMfccVectors;
- std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
-
- /**
- * @brief Returns a function to perform feature calculation and populates input tensor data with
- * MFCC data.
- *
- * Input tensor data type check is performed to choose correct MFCC feature data type.
- * If tensor has an integer data type then original features are quantised.
- *
- * Warning: MFCC calculator provided as input must have the same life scope as returned function.
- *
- * @param[in] mfcc MFCC feature calculator.
- * @param[in,out] inputTensor Input tensor pointer to store calculated features.
- * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
- * @return Function to be called providing audio sample and sliding window index.
- */
- std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
- TfLiteTensor* inputTensor,
- size_t cacheSize);
-
- template<class T>
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
- FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
- std::function<std::vector<T> (std::vector<int16_t>& )> compute);
- };
-
- /**
- * @brief Post-processing class for Keyword Spotting use case.
- * Implements methods declared by BasePostProcess and anything else needed
- * to populate result vector.
- */
- class KwsPostProcess : public BasePostProcess {
-
- private:
- TfLiteTensor* m_outputTensor; /* Model output tensor. */
- Classifier& m_kwsClassifier; /* KWS Classifier object. */
- const std::vector<std::string>& m_labels; /* KWS Labels. */
- std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */
-
- public:
- /**
- * @brief Constructor
- * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in/out] results Vector of classification results to store decoded outputs.
- **/
- KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
- const std::vector<std::string>& labels,
- std::vector<ClassificationResult>& results);
-
- /**
- * @brief Should perform post-processing of the result of inference then
- * populate KWS result data for any later use.
- * @return true if successful, false otherwise.
- **/
- bool DoPostProcess() override;
- };
-
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* KWS_PROCESSING_HPP */ \ No newline at end of file
diff --git a/source/use_case/kws/include/KwsResult.hpp b/source/use_case/kws/include/KwsResult.hpp
deleted file mode 100644
index 38f32b4..0000000
--- a/source/use_case/kws/include/KwsResult.hpp
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef KWS_RESULT_HPP
-#define KWS_RESULT_HPP
-
-#include "ClassificationResult.hpp"
-
-#include <vector>
-
-namespace arm {
-namespace app {
-namespace kws {
-
- using ResultVec = std::vector<arm::app::ClassificationResult>;
-
- /* Structure for holding kws result. */
- class KwsResult {
-
- public:
- ResultVec m_resultVec; /* Container for "thresholded" classification results. */
- float m_timeStamp; /* Audio timestamp for this result. */
- uint32_t m_inferenceNumber; /* Corresponding inference number. */
- float m_threshold; /* Threshold value for `m_resultVec`. */
-
- KwsResult() = delete;
- KwsResult(ResultVec& resultVec,
- const float timestamp,
- const uint32_t inferenceIdx,
- const float scoreThreshold) {
-
- this->m_threshold = scoreThreshold;
- this->m_timeStamp = timestamp;
- this->m_inferenceNumber = inferenceIdx;
-
- this->m_resultVec = ResultVec();
- for (auto & i : resultVec) {
- if (i.m_normalisedVal >= this->m_threshold) {
- this->m_resultVec.emplace_back(i);
- }
- }
- }
- ~KwsResult() = default;
- };
-
-} /* namespace kws */
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* KWS_RESULT_HPP */ \ No newline at end of file
diff --git a/source/use_case/kws/include/MicroNetKwsMfcc.hpp b/source/use_case/kws/include/MicroNetKwsMfcc.hpp
deleted file mode 100644
index b2565a3..0000000
--- a/source/use_case/kws/include/MicroNetKwsMfcc.hpp
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef KWS_MICRONET_MFCC_HPP
-#define KWS_MICRONET_MFCC_HPP
-
-#include "Mfcc.hpp"
-
-namespace arm {
-namespace app {
-namespace audio {
-
- /* Class to provide MicroNet specific MFCC calculation requirements. */
- class MicroNetKwsMFCC : public MFCC {
-
- public:
- static constexpr uint32_t ms_defaultSamplingFreq = 16000;
- static constexpr uint32_t ms_defaultNumFbankBins = 40;
- static constexpr uint32_t ms_defaultMelLoFreq = 20;
- static constexpr uint32_t ms_defaultMelHiFreq = 4000;
- static constexpr bool ms_defaultUseHtkMethod = true;
-
- explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen)
- : MFCC(MfccParams(
- ms_defaultSamplingFreq, ms_defaultNumFbankBins,
- ms_defaultMelLoFreq, ms_defaultMelHiFreq,
- numFeats, frameLen, ms_defaultUseHtkMethod))
- {}
- MicroNetKwsMFCC() = delete;
- ~MicroNetKwsMFCC() = default;
- };
-
-} /* namespace audio */
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* KWS_MICRONET_MFCC_HPP */ \ No newline at end of file
diff --git a/source/use_case/kws/include/MicroNetKwsModel.hpp b/source/use_case/kws/include/MicroNetKwsModel.hpp
deleted file mode 100644
index 3259c45..0000000
--- a/source/use_case/kws/include/MicroNetKwsModel.hpp
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef KWS_MICRONETMODEL_HPP
-#define KWS_MICRONETMODEL_HPP
-
-#include "Model.hpp"
-
-extern const int g_FrameLength;
-extern const int g_FrameStride;
-extern const float g_ScoreThreshold;
-
-namespace arm {
-namespace app {
-
- class MicroNetKwsModel : public Model {
- public:
- /* Indices for the expected model - based on input and output tensor shapes */
- static constexpr uint32_t ms_inputRowsIdx = 1;
- static constexpr uint32_t ms_inputColsIdx = 2;
- static constexpr uint32_t ms_outputRowsIdx = 2;
- static constexpr uint32_t ms_outputColsIdx = 3;
-
- protected:
- /** @brief Gets the reference to op resolver interface class. */
- const tflite::MicroOpResolver& GetOpResolver() override;
-
- /** @brief Adds operations to the op resolver instance. */
- bool EnlistOperations() override;
-
- const uint8_t* ModelPointer() override;
-
- size_t ModelSize() override;
-
- private:
- /* Maximum number of individual operations that can be enlisted. */
- static constexpr int ms_maxOpCnt = 7;
-
- /* A mutable op resolver instance. */
- tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver;
- };
-
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* KWS_MICRONETMODEL_HPP */
diff --git a/source/use_case/kws/src/KwsProcessing.cc b/source/use_case/kws/src/KwsProcessing.cc
deleted file mode 100644
index 328709d..0000000
--- a/source/use_case/kws/src/KwsProcessing.cc
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Copyright (c) 2022 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "KwsProcessing.hpp"
-#include "ImageUtils.hpp"
-#include "log_macros.h"
-#include "MicroNetKwsModel.hpp"
-
-namespace arm {
-namespace app {
-
- KwsPreProcess::KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numMfccFrames,
- int mfccFrameLength, int mfccFrameStride
- ):
- m_inputTensor{inputTensor},
- m_mfccFrameLength{mfccFrameLength},
- m_mfccFrameStride{mfccFrameStride},
- m_numMfccFrames{numMfccFrames},
- m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
- {
- this->m_mfcc.Init();
-
- /* Deduce the data length required for 1 inference from the network parameters. */
- this->m_audioDataWindowSize = this->m_numMfccFrames * this->m_mfccFrameStride +
- (this->m_mfccFrameLength - this->m_mfccFrameStride);
-
- /* Creating an MFCC feature sliding window for the data required for 1 inference. */
- this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize,
- this->m_mfccFrameLength, this->m_mfccFrameStride);
-
- /* For longer audio clips we choose to move by half the audio window size
- * => for a 1 second window size there is an overlap of 0.5 seconds. */
- this->m_audioDataStride = this->m_audioDataWindowSize / 2;
-
- /* To have the previously calculated features re-usable, stride must be multiple
- * of MFCC features window stride. Reduce stride through audio if needed. */
- if (0 != this->m_audioDataStride % this->m_mfccFrameStride) {
- this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride;
- }
-
- this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride;
-
- /* Calculate number of the feature vectors in the window overlap region.
- * These feature vectors will be reused.*/
- this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1
- - this->m_numMfccVectorsInAudioStride;
-
- /* Construct feature calculation function. */
- this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_inputTensor,
- this->m_numReusedMfccVectors);
-
- if (!this->m_mfccFeatureCalculator) {
- printf_err("Feature calculator not initialized.");
- }
- }
-
- bool KwsPreProcess::DoPreProcess(const void* data, size_t inputSize)
- {
- UNUSED(inputSize);
- if (data == nullptr) {
- printf_err("Data pointer is null");
- }
-
- /* Set the features sliding window to the new address. */
- auto input = static_cast<const int16_t*>(data);
- this->m_mfccSlidingWindow.Reset(input);
-
- /* Cache is only usable if we have more than 1 inference in an audio clip. */
- bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0;
-
- /* Use a sliding window to calculate MFCC features frame by frame. */
- while (this->m_mfccSlidingWindow.HasNext()) {
- const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
-
- std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow,
- mfccWindow + this->m_mfccFrameLength);
-
- /* Compute features for this window and write them to input tensor. */
- this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(),
- useCache, this->m_numMfccVectorsInAudioStride);
- }
-
- debug("Input tensor populated \n");
-
- return true;
- }
-
- /**
- * @brief Generic feature calculator factory.
- *
- * Returns lambda function to compute features using features cache.
- * Real features math is done by a lambda function provided as a parameter.
- * Features are written to input tensor memory.
- *
- * @tparam T Feature vector type.
- * @param[in] inputTensor Model input tensor pointer.
- * @param[in] cacheSize Number of feature vectors to cache. Defined by the sliding window overlap.
- * @param[in] compute Features calculator function.
- * @return Lambda function to compute features.
- */
- template<class T>
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
- KwsPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
- std::function<std::vector<T> (std::vector<int16_t>& )> compute)
- {
- /* Feature cache to be captured by lambda function. */
- static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
-
- return [=](std::vector<int16_t>& audioDataWindow,
- size_t index,
- bool useCache,
- size_t featuresOverlapIndex)
- {
- T* tensorData = tflite::GetTensorData<T>(inputTensor);
- std::vector<T> features;
-
- /* Reuse features from cache if cache is ready and sliding windows overlap.
- * Overlap is in the beginning of sliding window with a size of a feature cache. */
- if (useCache && index < featureCache.size()) {
- features = std::move(featureCache[index]);
- } else {
- features = std::move(compute(audioDataWindow));
- }
- auto size = features.size();
- auto sizeBytes = sizeof(T) * size;
- std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
-
- /* Start renewing cache as soon iteration goes out of the windows overlap. */
- if (index >= featuresOverlapIndex) {
- featureCache[index - featuresOverlapIndex] = std::move(features);
- }
- };
- }
-
- template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
- KwsPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
-
- template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
- KwsPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
- size_t cacheSize,
- std::function<std::vector<float>(std::vector<int16_t>&)> compute);
-
-
- std::function<void (std::vector<int16_t>&, int, bool, size_t)>
- KwsPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
- {
- std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
-
- TfLiteQuantization quant = inputTensor->quantization;
-
- if (kTfLiteAffineQuantization == quant.type) {
- auto *quantParams = (TfLiteAffineQuantization *) quant.params;
- const float quantScale = quantParams->scale->data[0];
- const int quantOffset = quantParams->zero_point->data[0];
-
- switch (inputTensor->type) {
- case kTfLiteInt8: {
- mfccFeatureCalc = this->FeatureCalc<int8_t>(inputTensor,
- cacheSize,
- [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
- quantScale,
- quantOffset);
- }
- );
- break;
- }
- default:
- printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
- }
- } else {
- mfccFeatureCalc = this->FeatureCalc<float>(inputTensor, cacheSize,
- [&mfcc](std::vector<int16_t>& audioDataWindow) {
- return mfcc.MfccCompute(audioDataWindow); }
- );
- }
- return mfccFeatureCalc;
- }
-
- KwsPostProcess::KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
- const std::vector<std::string>& labels,
- std::vector<ClassificationResult>& results)
- :m_outputTensor{outputTensor},
- m_kwsClassifier{classifier},
- m_labels{labels},
- m_results{results}
- {}
-
- bool KwsPostProcess::DoPostProcess()
- {
- return this->m_kwsClassifier.GetClassificationResults(
- this->m_outputTensor, this->m_results,
- this->m_labels, 1, true);
- }
-
-} /* namespace app */
-} /* namespace arm */ \ No newline at end of file
diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc
index e590c4a..3c35a7f 100644
--- a/source/use_case/kws/src/MainLoop.cc
+++ b/source/use_case/kws/src/MainLoop.cc
@@ -21,7 +21,18 @@
#include "Labels.hpp" /* For label strings. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
-#include "log_macros.h"
+#include "log_macros.h" /* Logging functions */
+#include "BufAttributes.hpp" /* Buffer attributes to be applied */
+
+namespace arm {
+namespace app {
+namespace kws {
+ static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
+ extern uint8_t *GetModelPointer();
+ extern size_t GetModelLen();
+} /* namespace kws */
+} /* namespace app */
+} /* namespace arm */
using KwsClassifier = arm::app::Classifier;
@@ -53,11 +64,22 @@ void main_loop()
arm::app::MicroNetKwsModel model; /* Model wrapper object. */
/* Load the model. */
- if (!model.Init()) {
+ if (!model.Init(arm::app::kws::tensorArena,
+ sizeof(arm::app::kws::tensorArena),
+ arm::app::kws::GetModelPointer(),
+ arm::app::kws::GetModelLen())) {
printf_err("Failed to initialise model\n");
return;
}
+#if !defined(ARM_NPU)
+ /* If it is not a NPU build check if the model contains a NPU operator */
+ if (model.ContainsEthosUOperator()) {
+ printf_err("No driver support for Ethos-U operator found in the model.\n");
+ return;
+ }
+#endif /* ARM_NPU */
+
/* Instantiate application context. */
arm::app::ApplicationContext caseContext;
@@ -65,9 +87,9 @@ void main_loop()
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("clipIndex", 0);
- caseContext.Set<int>("frameLength", g_FrameLength);
- caseContext.Set<int>("frameStride", g_FrameStride);
- caseContext.Set<float>("scoreThreshold", g_ScoreThreshold); /* Normalised score threshold. */
+ caseContext.Set<int>("frameLength", arm::app::kws::g_FrameLength);
+ caseContext.Set<int>("frameStride", arm::app::kws::g_FrameStride);
+ caseContext.Set<float>("scoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */
KwsClassifier classifier; /* classifier wrapper object. */
caseContext.Set<arm::app::Classifier&>("classifier", classifier);
@@ -114,4 +136,4 @@ void main_loop()
}
} while (executionSuccessful && bUseMenu);
info("Main loop terminated.\n");
-} \ No newline at end of file
+}
diff --git a/source/use_case/kws/src/MicroNetKwsModel.cc b/source/use_case/kws/src/MicroNetKwsModel.cc
deleted file mode 100644
index 1c38525..0000000
--- a/source/use_case/kws/src/MicroNetKwsModel.cc
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "MicroNetKwsModel.hpp"
-#include "log_macros.h"
-
-const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver()
-{
- return this->m_opResolver;
-}
-
-bool arm::app::MicroNetKwsModel::EnlistOperations()
-{
- this->m_opResolver.AddReshape();
- this->m_opResolver.AddAveragePool2D();
- this->m_opResolver.AddConv2D();
- this->m_opResolver.AddDepthwiseConv2D();
- this->m_opResolver.AddFullyConnected();
- this->m_opResolver.AddRelu();
-
-#if defined(ARM_NPU)
- if (kTfLiteOk == this->m_opResolver.AddEthosU()) {
- info("Added %s support to op resolver\n",
- tflite::GetString_ETHOSU());
- } else {
- printf_err("Failed to add Arm NPU support to op resolver.");
- return false;
- }
-#endif /* ARM_NPU */
- return true;
-}
-
-extern uint8_t* GetModelPointer();
-const uint8_t* arm::app::MicroNetKwsModel::ModelPointer()
-{
- return GetModelPointer();
-}
-
-extern size_t GetModelLen();
-size_t arm::app::MicroNetKwsModel::ModelSize()
-{
- return GetModelLen();
-} \ No newline at end of file
diff --git a/source/use_case/kws/usecase.cmake b/source/use_case/kws/usecase.cmake
index 9f3736e..d9985c7 100644
--- a/source/use_case/kws/usecase.cmake
+++ b/source/use_case/kws/usecase.cmake
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#----------------------------------------------------------------------------
+# Append the API to use for this use case
+list(APPEND ${use_case}_API_LIST "kws")
USER_OPTION(${use_case}_FILE_PATH "Directory with custom WAV input files, or path to a single WAV file, to use in the evaluation application."
${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/samples/
@@ -96,4 +98,5 @@ generate_tflite_code(
MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH}
DESTINATION ${SRC_GEN_DIR}
EXPRESSIONS ${EXTRA_MODEL_CODE}
+ NAMESPACE "arm" "app" "kws"
)