diff options
author | Richard Burton <richard.burton@arm.com> | 2022-05-04 09:45:02 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-05-04 09:45:02 +0100 |
commit | 4e002791bc6781b549c6951cfe44f918289d7e82 (patch) | |
tree | b639243b5fa433657c207783a384bad1ed248536 /source/use_case/ad/src | |
parent | dd6d07b24bbf9023ebe8e8927be8aac3291d0f58 (diff) | |
download | ml-embedded-evaluation-kit-4e002791bc6781b549c6951cfe44f918289d7e82.tar.gz |
MLECO-3173: Add AD, KWS_ASR and Noise reduction use case API's
Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37
Diffstat (limited to 'source/use_case/ad/src')
-rw-r--r-- | source/use_case/ad/src/AdPostProcessing.cc | 115 | ||||
-rw-r--r-- | source/use_case/ad/src/AdProcessing.cc | 208 | ||||
-rw-r--r-- | source/use_case/ad/src/MainLoop.cc | 7 | ||||
-rw-r--r-- | source/use_case/ad/src/UseCaseHandler.cc | 291 |
4 files changed, 279 insertions, 342 deletions
diff --git a/source/use_case/ad/src/AdPostProcessing.cc b/source/use_case/ad/src/AdPostProcessing.cc deleted file mode 100644 index c461875..0000000 --- a/source/use_case/ad/src/AdPostProcessing.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "AdPostProcessing.hpp" -#include "log_macros.h" - -#include <numeric> -#include <cmath> -#include <string> - -namespace arm { -namespace app { - - template<typename T> - std::vector<float> Dequantize(TfLiteTensor* tensor) { - - if (tensor == nullptr) { - printf_err("Tensor is null pointer can not dequantize.\n"); - return std::vector<float>(); - } - T* tensorData = tflite::GetTensorData<T>(tensor); - - uint32_t totalOutputSize = 1; - for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){ - totalOutputSize *= tensor->dims->data[inputDim]; - } - - /* For getting the floating point values, we need quantization parameters */ - QuantParams quantParams = GetTensorQuantParams(tensor); - - std::vector<float> dequantizedOutput(totalOutputSize); - - for (size_t i = 0; i < totalOutputSize; ++i) { - dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset); - } - - return dequantizedOutput; - } - - void Softmax(std::vector<float>& inputVector) { - auto start = inputVector.begin(); - auto end = inputVector.end(); - - /* Fix for numerical stability and apply exp. */ - float maxValue = *std::max_element(start, end); - for (auto it = start; it!=end; ++it) { - *it = std::exp((*it) - maxValue); - } - - float sumExp = std::accumulate(start, end, 0.0f); - - for (auto it = start; it!=end; ++it) { - *it = (*it)/sumExp; - } - } - - int8_t OutputIndexFromFileName(std::string wavFileName) { - /* Filename is assumed in the form machine_id_00.wav */ - std::string delimiter = "_"; /* First character used to split the file name up. */ - size_t delimiterStart; - std::string subString; - size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ - - for (size_t i = 0; i < machineIdxInString; ++i) { - delimiterStart = wavFileName.find(delimiter); - subString = wavFileName.substr(0, delimiterStart); - wavFileName.erase(0, delimiterStart + delimiter.length()); - } - - /* At this point substring should be 00.wav */ - delimiter = "."; /* Second character used to split the file name up. */ - delimiterStart = subString.find(delimiter); - subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - - auto is_number = [](const std::string& str) -> bool - { - std::string::const_iterator it = str.begin(); - while (it != str.end() && std::isdigit(*it)) ++it; - return !str.empty() && it == str.end(); - }; - - const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1; - - /* Return corresponding index in the output vector. */ - if (machineIdx == 0) { - return 0; - } else if (machineIdx == 2) { - return 1; - } else if (machineIdx == 4) { - return 2; - } else if (machineIdx == 6) { - return 3; - } else { - printf_err("%d is an invalid machine index \n", machineIdx); - return -1; - } - } - - template std::vector<float> Dequantize<uint8_t>(TfLiteTensor* tensor); - template std::vector<float> Dequantize<int8_t>(TfLiteTensor* tensor); -} /* namespace app */ -} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/ad/src/AdProcessing.cc b/source/use_case/ad/src/AdProcessing.cc new file mode 100644 index 0000000..a33131c --- /dev/null +++ b/source/use_case/ad/src/AdProcessing.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AdProcessing.hpp" + +#include "AdModel.hpp" + +namespace arm { +namespace app { + +AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor, + uint32_t melSpectrogramFrameLen, + uint32_t melSpectrogramFrameStride, + float adModelTrainingMean): + m_validInstance{false}, + m_melSpectrogramFrameLen{melSpectrogramFrameLen}, + m_melSpectrogramFrameStride{melSpectrogramFrameStride}, + /**< Model is trained on features downsampled 2x */ + m_inputResizeScale{2}, + /**< We are choosing to move by 20 frames across the audio for each inference. */ + m_numMelSpecVectorsInAudioStride{20}, + m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride}, + m_melSpec{melSpectrogramFrameLen} +{ + if (!inputTensor) { + printf_err("Invalid input tensor provided to pre-process\n"); + return; + } + + TfLiteIntArray* inputShape = inputTensor->dims; + + if (!inputShape) { + printf_err("Invalid input tensor dims\n"); + return; + } + + const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx]; + const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx]; + + /* Deduce the data length required for 1 inference from the network parameters. */ + this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) * + melSpectrogramFrameStride) + + melSpectrogramFrameLen; + this->m_numReusedFeatureVectors = kNumRows - + (this->m_numMelSpecVectorsInAudioStride / + this->m_inputResizeScale); + this->m_melSpec.Init(); + + /* Creating a Mel Spectrogram sliding window for the data required for 1 inference. + * "resizing" done here by multiplying stride by resize scale. */ + this->m_melWindowSlider = audio::SlidingWindow<const int16_t>( + nullptr, /* to be populated later. */ + this->m_audioDataWindowSize, + melSpectrogramFrameLen, + melSpectrogramFrameStride * this->m_inputResizeScale); + + /* Construct feature calculation function. */ + this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor, + this->m_numReusedFeatureVectors, + adModelTrainingMean); + this->m_validInstance = true; +} + +bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize) +{ + /* Check that we have a valid instance. */ + if (!this->m_validInstance) { + printf_err("Invalid pre-processor instance\n"); + return false; + } + + /* We expect that we can traverse the size with which the MEL spectrogram + * sliding window was initialised with. */ + if (!input || inputSize < this->m_audioDataWindowSize) { + printf_err("Invalid input provided for pre-processing\n"); + return false; + } + + /* We moved to the next window - set the features sliding to the new address. */ + this->m_melWindowSlider.Reset(static_cast<const int16_t*>(input)); + + /* The first window does not have cache ready. */ + const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0; + + /* Start calculating features inside one audio sliding window. */ + while (this->m_melWindowSlider.HasNext()) { + const int16_t* melSpecWindow = this->m_melWindowSlider.Next(); + std::vector<int16_t> melSpecAudioData = std::vector<int16_t>( + melSpecWindow, + melSpecWindow + this->m_melSpectrogramFrameLen); + + /* Compute features for this window and write them to input tensor. */ + this->m_featureCalc(melSpecAudioData, + this->m_melWindowSlider.Index(), + useCache, + this->m_numMelSpecVectorsInAudioStride, + this->m_inputResizeScale); + } + + return true; +} + +uint32_t AdPreProcess::GetAudioWindowSize() +{ + return this->m_audioDataWindowSize; +} + +uint32_t AdPreProcess::GetAudioDataStride() +{ + return this->m_audioDataStride; +} + +void AdPreProcess::SetAudioWindowIndex(uint32_t idx) +{ + this->m_audioWindowIndex = idx; +} + +AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) : + m_outputTensor {outputTensor} +{} + +bool AdPostProcess::DoPostProcess() +{ + switch (this->m_outputTensor->type) { + case kTfLiteInt8: + this->Dequantize<int8_t>(); + break; + default: + printf_err("Unsupported tensor type"); + return false; + } + + math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec); + return true; +} + +float AdPostProcess::GetOutputValue(uint32_t index) +{ + if (index < this->m_dequantizedOutputVec.size()) { + return this->m_dequantizedOutputVec[index]; + } + printf_err("Invalid index for output\n"); + return 0.0; +} + +std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)> +GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, + TfLiteTensor* inputTensor, + size_t cacheSize, + float trainingMean) +{ + std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + + auto* quantParams = static_cast<TfLiteAffineQuantization*>(quant.params); + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + melSpecFeatureCalc = FeatureCalc<int8_t>( + inputTensor, + cacheSize, + [=, &melSpec](std::vector<int16_t>& audioDataWindow) { + return melSpec.MelSpecComputeQuant<int8_t>( + audioDataWindow, + quantScale, + quantOffset, + trainingMean); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + } else { + melSpecFeatureCalc = FeatureCalc<float>( + inputTensor, + cacheSize, + [=, &melSpec]( + std::vector<int16_t>& audioDataWindow) { + return melSpec.ComputeMelSpec( + audioDataWindow, + trainingMean); + }); + } + return melSpecFeatureCalc; +} + +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc index 23d1e51..140359b 100644 --- a/source/use_case/ad/src/MainLoop.cc +++ b/source/use_case/ad/src/MainLoop.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "hal.h" /* Brings in platform definitions */ #include "InputFiles.hpp" /* For input data */ #include "AdModel.hpp" /* Model class for running inference */ #include "UseCaseCommonUtils.hpp" /* Utils functions */ @@ -63,8 +62,8 @@ void main_loop() caseContext.Set<arm::app::Profiler&>("profiler", profiler); caseContext.Set<arm::app::Model&>("model", model); caseContext.Set<uint32_t>("clipIndex", 0); - caseContext.Set<int>("frameLength", g_FrameLength); - caseContext.Set<int>("frameStride", g_FrameStride); + caseContext.Set<uint32_t>("frameLength", g_FrameLength); + caseContext.Set<uint32_t>("frameStride", g_FrameStride); caseContext.Set<float>("scoreThreshold", g_ScoreThreshold); caseContext.Set<float>("trainingMean", g_TrainingMean); diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc index 5585f36..0179d6b 100644 --- a/source/use_case/ad/src/UseCaseHandler.cc +++ b/source/use_case/ad/src/UseCaseHandler.cc @@ -24,8 +24,8 @@ #include "AudioUtils.hpp" #include "ImageUtils.hpp" #include "UseCaseCommonUtils.hpp" -#include "AdPostProcessing.hpp" #include "log_macros.h" +#include "AdProcessing.hpp" namespace arm { namespace app { @@ -39,32 +39,17 @@ namespace app { **/ static bool PresentInferenceResult(float result, float threshold); - /** - * @brief Returns a function to perform feature calculation and populates input tensor data with - * MelSpe data. - * - * Input tensor data type check is performed to choose correct MFCC feature data type. - * If tensor has an integer data type then original features are quantised. - * - * Warning: mfcc calculator provided as input must have the same life scope as returned function. - * - * @param[in] melSpec MFCC feature calculator. - * @param[in,out] inputTensor Input tensor pointer to store calculated features. - * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors). - * @param[in] trainingMean Training mean. - * @return function function to be called providing audio sample and sliding window index. - */ - static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)> - GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, - TfLiteTensor* inputTensor, - size_t cacheSize, - float trainingMean); - - /* Vibration classification handler */ + /** @brief Given a wav file name return AD model output index. + * @param[in] wavFileName Audio WAV filename. + * File name should be in format anything_goes_XX_here.wav + * where XX is the machine ID e.g. 00, 02, 04 or 06 + * @return AD model output index as 8 bit integer. + **/ + static int8_t OutputIndexFromFileName(std::string wavFileName); + + /* Anomaly Detection inference handler */ bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) { - auto& profiler = ctx.Get<Profiler&>("profiler"); - constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; @@ -81,8 +66,9 @@ namespace app { return false; } - const auto frameLength = ctx.Get<int>("frameLength"); - const auto frameStride = ctx.Get<int>("frameStride"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength"); + const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride"); const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); const auto trainingMean = ctx.Get<float>("trainingMean"); auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); @@ -95,21 +81,13 @@ namespace app { return false; } - TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t kNumRows = inputShape->data[1]; - const uint32_t kNumCols = inputShape->data[2]; + AdPreProcess preProcess{ + inputTensor, + melSpecFrameLength, + melSpecFrameStride, + trainingMean}; - audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength); - melSpec.Init(); - - /* Deduce the data length required for 1 inference from the network parameters. */ - const uint8_t inputResizeScale = 2; - const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength; - - /* We are choosing to move by 20 frames across the audio for each inference. */ - const uint8_t nMelSpecVectorsInAudioStride = 20; - - auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride; + AdPostProcess postProcess{outputTensor}; do { hal_lcd_clear(COLOR_BLACK); @@ -122,29 +100,12 @@ namespace app { return false; } - /* Creating a Mel Spectrogram sliding window for the data required for 1 inference. - * "resizing" done here by multiplying stride by resize scale. */ - auto audioMelSpecWindowSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - audioDataWindowSize, frameLength, - frameStride * inputResizeScale); - /* Creating a sliding window through the whole audio clip. */ auto audioDataSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - get_audio_array_size(currentIndex), - audioDataWindowSize, audioDataStride); - - /* Calculate number of the feature vectors in the window overlap region taking into account resizing. - * These feature vectors will be reused.*/ - auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale); - - /* Construct feature calculation function. */ - auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor, - numberOfReusedFeatureVectors, trainingMean); - if (!melSpecFeatureCalc){ - return false; - } + get_audio_array(currentIndex), + get_audio_array_size(currentIndex), + preProcess.GetAudioWindowSize(), + preProcess.GetAudioDataStride()); /* Result is an averaged sum over inferences. */ float result = 0; @@ -152,30 +113,18 @@ namespace app { /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex)); + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + info("Running inference on audio clip %" PRIu32 " => %s\n", + currentIndex, get_filename(currentIndex)); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { - const int16_t *inferenceWindow = audioDataSlider.Next(); - - /* We moved to the next window - set the features sliding to the new address. */ - audioMelSpecWindowSlider.Reset(inferenceWindow); + const int16_t* inferenceWindow = audioDataSlider.Next(); - /* The first window does not have cache ready. */ - bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; - - /* Start calculating features inside one audio sliding window. */ - while (audioMelSpecWindowSlider.HasNext()) { - const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next(); - std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(melSpecWindow, - melSpecWindow + frameLength); - - /* Compute features for this window and write them to input tensor. */ - melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(), - useCache, nMelSpecVectorsInAudioStride, inputResizeScale); - } + preProcess.SetAudioWindowIndex(audioDataSlider.Index()); + preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize()); info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); @@ -185,13 +134,11 @@ namespace app { return false; } - /* Use the negative softmax score of the corresponding index as the outlier score */ - std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor); - Softmax(dequantOutput); - result += -dequantOutput[machineOutputIndex]; + postProcess.DoPostProcess(); + result += 0 - postProcess.GetOutputValue(machineOutputIndex); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(outputTensor); + DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ @@ -218,7 +165,6 @@ namespace app { return true; } - static bool PresentInferenceResult(float result, float threshold) { constexpr uint32_t dataPsnTxtStartX1 = 20; @@ -251,148 +197,47 @@ namespace app { return true; } - /** - * @brief Generic feature calculator factory. - * - * Returns lambda function to compute features using features cache. - * Real features math is done by a lambda function provided as a parameter. - * Features are written to input tensor memory. - * - * @tparam T feature vector type. - * @param inputTensor model input tensor pointer. - * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. - * @param compute features calculator function. - * @return lambda function to compute features. - */ - template<class T> - std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> - FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, - std::function<std::vector<T> (std::vector<int16_t>& )> compute) + static int8_t OutputIndexFromFileName(std::string wavFileName) { - /* Feature cache to be captured by lambda function*/ - static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize); - - return [=](std::vector<int16_t>& audioDataWindow, - size_t index, - bool useCache, - size_t featuresOverlapIndex, - size_t resizeScale) - { - T *tensorData = tflite::GetTensorData<T>(inputTensor); - std::vector<T> features; - - /* Reuse features from cache if cache is ready and sliding windows overlap. - * Overlap is in the beginning of sliding window with a size of a feature cache. */ - if (useCache && index < featureCache.size()) { - features = std::move(featureCache[index]); - } else { - features = std::move(compute(audioDataWindow)); - } - auto size = features.size() / resizeScale; - auto sizeBytes = sizeof(T); + /* Filename is assumed in the form machine_id_00.wav */ + std::string delimiter = "_"; /* First character used to split the file name up. */ + size_t delimiterStart; + std::string subString; + size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ + + for (size_t i = 0; i < machineIdxInString; ++i) { + delimiterStart = wavFileName.find(delimiter); + subString = wavFileName.substr(0, delimiterStart); + wavFileName.erase(0, delimiterStart + delimiter.length()); + } - /* Input should be transposed and "resized" by skipping elements. */ - for (size_t outIndex = 0; outIndex < size; outIndex++) { - std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes); - } + /* At this point substring should be 00.wav */ + delimiter = "."; /* Second character used to split the file name up. */ + delimiterStart = subString.find(delimiter); + subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - /* Start renewing cache as soon iteration goes out of the windows overlap. */ - if (index >= featuresOverlapIndex / resizeScale) { - featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features); - } + auto is_number = [](const std::string& str) -> bool + { + std::string::const_iterator it = str.begin(); + while (it != str.end() && std::isdigit(*it)) ++it; + return !str.empty() && it == str.end(); }; - } - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)> - FeatureCalc<int8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)> - FeatureCalc<uint8_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute); - - template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)> - FeatureCalc<int16_t>(TfLiteTensor* inputTensor, - size_t cacheSize, - std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute); - - template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)> - FeatureCalc<float>(TfLiteTensor *inputTensor, - size_t cacheSize, - std::function<std::vector<float>(std::vector<int16_t>&)> compute); - - - static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)> - GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean) - { - std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc; - - TfLiteQuantization quant = inputTensor->quantization; - - if (kTfLiteAffineQuantization == quant.type) { - - auto *quantParams = (TfLiteAffineQuantization *) quant.params; - const float quantScale = quantParams->scale->data[0]; - const int quantOffset = quantParams->zero_point->data[0]; - - switch (inputTensor->type) { - case kTfLiteInt8: { - melSpecFeatureCalc = FeatureCalc<int8_t>(inputTensor, - cacheSize, - [=, &melSpec](std::vector<int16_t>& audioDataWindow) { - return melSpec.MelSpecComputeQuant<int8_t>( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - case kTfLiteUInt8: { - melSpecFeatureCalc = FeatureCalc<uint8_t>(inputTensor, - cacheSize, - [=, &melSpec](std::vector<int16_t>& audioDataWindow) { - return melSpec.MelSpecComputeQuant<uint8_t>( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - case kTfLiteInt16: { - melSpecFeatureCalc = FeatureCalc<int16_t>(inputTensor, - cacheSize, - [=, &melSpec](std::vector<int16_t>& audioDataWindow) { - return melSpec.MelSpecComputeQuant<int16_t>( - audioDataWindow, - quantScale, - quantOffset, - trainingMean); - } - ); - break; - } - default: - printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); - } - + const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1; + + /* Return corresponding index in the output vector. */ + if (machineIdx == 0) { + return 0; + } else if (machineIdx == 2) { + return 1; + } else if (machineIdx == 4) { + return 2; + } else if (machineIdx == 6) { + return 3; } else { - melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc<float>(inputTensor, - cacheSize, - [=, &melSpec]( - std::vector<int16_t>& audioDataWindow) { - return melSpec.ComputeMelSpec( - audioDataWindow, - trainingMean); - }); + printf_err("%d is an invalid machine index \n", machineIdx); + return -1; } - return melSpecFeatureCalc; } } /* namespace app */ |