diff options
author | Richard Burton <richard.burton@arm.com> | 2022-05-04 09:45:02 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-05-04 09:45:02 +0100 |
commit | 4e002791bc6781b549c6951cfe44f918289d7e82 (patch) | |
tree | b639243b5fa433657c207783a384bad1ed248536 /source/use_case/ad/include | |
parent | dd6d07b24bbf9023ebe8e8927be8aac3291d0f58 (diff) | |
download | ml-embedded-evaluation-kit-4e002791bc6781b549c6951cfe44f918289d7e82.tar.gz |
MLECO-3173: Add AD, KWS_ASR and Noise reduction use case API's
Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
Diffstat (limited to 'source/use_case/ad/include')
-rw-r--r-- | source/use_case/ad/include/AdModel.hpp | 8 | ||||
-rw-r--r-- | source/use_case/ad/include/AdPostProcessing.hpp | 50 | ||||
-rw-r--r-- | source/use_case/ad/include/AdProcessing.hpp | 230 |
3 files changed, 237 insertions, 51 deletions
diff --git a/source/use_case/ad/include/AdModel.hpp b/source/use_case/ad/include/AdModel.hpp
index 8d914c4..2195a7c 100644
--- a/source/use_case/ad/include/AdModel.hpp
+++ b/source/use_case/ad/include/AdModel.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +28,12 @@ namespace arm {
 namespace app {
 
     class AdModel : public Model {
+
+    public:
+        /* Indices for the expected model - based on input tensor shape */
+        static constexpr uint32_t ms_inputRowsIdx = 1;
+        static constexpr uint32_t ms_inputColsIdx = 2;
+
     protected:
         /** @brief Gets the reference to op resolver interface class */
         const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/ad/include/AdPostProcessing.hpp b/source/use_case/ad/include/AdPostProcessing.hpp
deleted file mode 100644
index 7eaec84..0000000
--- a/source/use_case/ad/include/AdPostProcessing.hpp
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef ADPOSTPROCESSING_HPP
-#define ADPOSTPROCESSING_HPP
-
-#include "TensorFlowLiteMicro.hpp"
-
-#include <vector>
-
-namespace arm {
-namespace app {
-
-    /** @brief Dequantize TensorFlow Lite Micro tensor.
-     * @param[in] tensor Pointer to the TensorFlow Lite Micro tensor to be dequantized.
-     * @return Vector with the dequantized tensor values.
-     **/
-    template<typename T>
-    std::vector<float> Dequantize(TfLiteTensor* tensor);
-
-    /**
-     * @brief Calculates the softmax of vector in place. **/
-    void Softmax(std::vector<float>& inputVector);
-
-
-    /** @brief Given a wav file name return AD model output index.
-     * @param[in] wavFileName Audio WAV filename.
-     *            File name should be in format anything_goes_XX_here.wav
-     *            where XX is the machine ID e.g. 00, 02, 04 or 06
-     * @return AD model output index as 8 bit integer.
-     **/
-    int8_t OutputIndexFromFileName(std::string wavFileName);
-
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* ADPOSTPROCESSING_HPP */
diff --git a/source/use_case/ad/include/AdProcessing.hpp b/source/use_case/ad/include/AdProcessing.hpp
new file mode 100644
index 0000000..9abf6f1
--- /dev/null
+++ b/source/use_case/ad/include/AdProcessing.hpp
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef AD_PROCESSING_HPP
+#define AD_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "AudioUtils.hpp"
+#include "AdMelSpectrogram.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+    /**
+     * @brief Pre-processing class for anomaly detection use case.
+     *        Implements methods declared by BasePreProcess and anything else needed
+     *        to populate input tensors ready for inference.
+     */
+    class AdPreProcess : public BasePreProcess {
+
+    public:
+        /**
+         * @brief Constructor for AdPreProcess class objects
+         * @param[in] inputTensor input tensor pointer from the tensor arena.
+         * @param[in] melSpectrogramFrameLen MEL spectrogram's frame length
+         * @param[in] melSpectrogramFrameStride MEL spectrogram's frame stride
+         * @param[in] adModelTrainingMean Training mean for the Anomaly detection model being used.
+         */
+        explicit AdPreProcess(TfLiteTensor* inputTensor,
+                              uint32_t melSpectrogramFrameLen,
+                              uint32_t melSpectrogramFrameStride,
+                              float adModelTrainingMean);
+
+        ~AdPreProcess() = default;
+
+        /**
+         * @brief Function to invoke pre-processing and populate the input vector
+         * @param input pointer to input data. For anomaly detection, this is the pointer to
+         *              the audio data.
+         * @param inputSize Size of the data being passed in for pre-processing.
+         * @return True if successful, false otherwise.
+         */
+        bool DoPreProcess(const void* input, size_t inputSize) override;
+
+        /**
+         * @brief Getter function for audio window size computed when constructing
+         *        the class object.
+         * @return Audio window size as 32 bit unsigned integer.
+         */
+        uint32_t GetAudioWindowSize();
+
+        /**
+         * @brief Getter function for audio window stride computed when constructing
+         *        the class object.
+         * @return Audio window stride as 32 bit unsigned integer.
+         */
+        uint32_t GetAudioDataStride();
+
+        /**
+         * @brief Setter function for current audio index. This is only used for evaluating
+         *        if previously computed features can be re-used from cache.
+         */
+        void SetAudioWindowIndex(uint32_t idx);
+
+    private:
+        bool m_validInstance{false}; /**< Indicates the current object is valid. */
+        uint32_t m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */
+        uint32_t m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */
+        uint8_t m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */
+        uint32_t m_numMelSpecVectorsInAudioStride{}; /**< Number of frames to move across the audio. */
+        uint32_t m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */
+        uint32_t m_audioDataStride{}; /**< Audio window stride computed. */
+        uint32_t m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */
+        uint32_t m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */
+
+        audio::SlidingWindow<const int16_t> m_melWindowSlider; /**< Internal MEL spectrogram window slider */
+        audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */
+        std::function<void
+            (std::vector<int16_t>&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator object */
+    };
+
+    class AdPostProcess : public BasePostProcess {
+    public:
+        /**
+         * @brief Constructor for AdPostProcess object.
+         * @param[in] outputTensor Output tensor pointer.
+         */
+        explicit AdPostProcess(TfLiteTensor* outputTensor);
+
+        ~AdPostProcess() = default;
+
+        /**
+         * @brief Function to do the post-processing on the output tensor.
+         * @return True if successful, false otherwise.
+         */
+        bool DoPostProcess() override;
+
+        /**
+         * @brief Getter function for an element from the de-quantised output vector.
+         * @param index Index of the element to be retrieved.
+         * @return index represented as a 32 bit floating point number.
+         */
+        float GetOutputValue(uint32_t index);
+
+    private:
+        TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */
+        std::vector<float> m_dequantizedOutputVec{}; /**< Internal output vector */
+
+        /**
+         * @brief De-quantizes and flattens the output tensor into a vector.
+         * @tparam T template parameter to indicate data type.
+         * @return True if successful, false otherwise.
+         */
+        template<typename T>
+        bool Dequantize()
+        {
+            TfLiteTensor* tensor = this->m_outputTensor;
+            if (tensor == nullptr) {
+                printf_err("Invalid output tensor.\n");
+                return false;
+            }
+            T* tensorData = tflite::GetTensorData<T>(tensor);
+
+            uint32_t totalOutputSize = 1;
+            for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
+                totalOutputSize *= tensor->dims->data[inputDim];
+            }
+
+            /* For getting the floating point values, we need quantization parameters */
+            QuantParams quantParams = GetTensorQuantParams(tensor);
+
+            this->m_dequantizedOutputVec = std::vector<float>(totalOutputSize, 0);
+
+            for (size_t i = 0; i < totalOutputSize; ++i) {
+                this->m_dequantizedOutputVec[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
+            }
+
+            return true;
+        }
+    };
+
+    /* Templated instances available: */
+    template bool AdPostProcess::Dequantize<int8_t>();
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     * Returns lambda function to compute features using features cache.
+     * Real features math is done by a lambda function provided as a parameter.
+     * Features are written to input tensor memory.
+     *
+     * @tparam T feature vector type.
+     * @param inputTensor model input tensor pointer.
+     * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param compute features calculator function.
+     * @return lambda function to compute features.
+     */
+    template<class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
+    {
+        /* Feature cache to be captured by lambda function*/
+        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex,
+                   size_t resizeScale)
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from cache if cache is ready and sliding windows overlap.
+             * Overlap is in the beginning of sliding window with a size of a feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = std::move(compute(audioDataWindow));
+            }
+            auto size = features.size() / resizeScale;
+            auto sizeBytes = sizeof(T);
+
+            /* Input should be transposed and "resized" by skipping elements. */
+            for (size_t outIndex = 0; outIndex < size; outIndex++) {
+                std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
+            }
+
+            /* Start renewing cache as soon iteration goes out of the windows overlap. */
+            if (index >= featuresOverlapIndex / resizeScale) {
+                featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
+            }
+        };
+    }
+
+    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
+    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
+                        size_t cacheSize,
+                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
+
+    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
+    FeatureCalc<float>(TfLiteTensor *inputTensor,
+                       size_t cacheSize,
+                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
+
+    std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
+    GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
+                         TfLiteTensor* inputTensor,
+                         size_t cacheSize,
+                         float trainingMean);
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* AD_PROCESSING_HPP */