diff options
Diffstat (limited to 'source/use_case/ad/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/ad/src/UseCaseHandler.cc | 89 |
1 files changed, 44 insertions, 45 deletions
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc index 222aaf3..c71fdeb 100644 --- a/source/use_case/ad/src/UseCaseHandler.cc +++ b/source/use_case/ad/src/UseCaseHandler.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,16 +16,16 @@ */ #include "UseCaseHandler.hpp" -#include "AdModel.hpp" -#include "InputFiles.hpp" -#include "Classifier.hpp" -#include "hal.h" #include "AdMelSpectrogram.hpp" +#include "AdModel.hpp" +#include "AdProcessing.hpp" #include "AudioUtils.hpp" +#include "Classifier.hpp" #include "ImageUtils.hpp" +#include "InputFiles.hpp" #include "UseCaseCommonUtils.hpp" +#include "hal.h" #include "log_macros.h" -#include "AdProcessing.hpp" namespace arm { namespace app { @@ -44,7 +44,7 @@ namespace app { * File name should be in format anything_goes_XX_here.wav * where XX is the machine ID e.g. 00, 02, 04 or 06 * @return AD model output index as 8 bit integer. - **/ + **/ static int8_t OutputIndexFromFileName(std::string wavFileName); /* Anomaly Detection inference handler */ @@ -57,7 +57,7 @@ namespace app { /* If the request has a valid size, set the audio index */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { + if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) { return false; } } @@ -66,26 +66,22 @@ namespace app { return false; } - auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& profiler = ctx.Get<Profiler&>("profiler"); const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength"); const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride"); - const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); - const auto trainingMean = ctx.Get<float>("trainingMean"); - auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); + const auto scoreThreshold = ctx.Get<float>("scoreThreshold"); + const auto trainingMean = ctx.Get<float>("trainingMean"); + auto startClipIdx = ctx.Get<uint32_t>("clipIndex"); TfLiteTensor* outputTensor = model.GetOutputTensor(0); - TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; } - AdPreProcess preProcess{ - inputTensor, - melSpecFrameLength, - melSpecFrameStride, - trainingMean}; + AdPreProcess preProcess{inputTensor, melSpecFrameLength, melSpecFrameStride, trainingMean}; AdPostProcess postProcess{outputTensor}; @@ -95,17 +91,17 @@ namespace app { auto currentIndex = ctx.Get<uint32_t>("clipIndex"); /* Get the output index to look at based on id in the filename. */ - int8_t machineOutputIndex = OutputIndexFromFileName(get_filename(currentIndex)); + int8_t machineOutputIndex = OutputIndexFromFileName(GetFilename(currentIndex)); if (machineOutputIndex == -1) { return false; } /* Creating a sliding window through the whole audio clip. */ - auto audioDataSlider = audio::SlidingWindow<const int16_t>( - get_audio_array(currentIndex), - get_audio_array_size(currentIndex), - preProcess.GetAudioWindowSize(), - preProcess.GetAudioDataStride()); + auto audioDataSlider = + audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex), + GetAudioArraySize(currentIndex), + preProcess.GetAudioWindowSize(), + preProcess.GetAudioDataStride()); /* Result is an averaged sum over inferences. */ float result = 0; @@ -113,11 +109,11 @@ namespace app { /* Display message on the LCD - inference running. */ std::string str_inf{"Running inference... "}; hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); info("Running inference on audio clip %" PRIu32 " => %s\n", - currentIndex, get_filename(currentIndex)); + currentIndex, + GetFilename(currentIndex)); /* Start sliding through audio clip. */ while (audioDataSlider.HasNext()) { @@ -126,7 +122,8 @@ namespace app { preProcess.SetAudioWindowIndex(audioDataSlider.Index()); preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize()); - info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + info("Inference %zu/%zu\n", + audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1); /* Run inference over this audio clip sliding window */ @@ -139,7 +136,7 @@ namespace app { #if VERIFY_TEST_OUTPUT DumpTensor(outputTensor); -#endif /* VERIFY_TEST_OUTPUT */ +#endif /* VERIFY_TEST_OUTPUT */ } /* while (audioDataSlider.HasNext()) */ /* Use average over whole clip as final score. */ @@ -148,8 +145,7 @@ namespace app { /* Erase. */ str_inf = std::string(str_inf.size(), ' '); hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); ctx.Set<float>("result", result); if (!PresentInferenceResult(result, scoreThreshold)) { @@ -158,7 +154,7 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"clipIndex"); + IncrementAppCtxIfmIdx(ctx, "clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); @@ -176,8 +172,10 @@ namespace app { /* Display each result */ uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - std::string anomalyScore = std::string{"Average anomaly score is: "} + std::to_string(result); - std::string anomalyThreshold = std::string("Anomaly threshold is: ") + std::to_string(threshold); + std::string anomalyScore = + std::string{"Average anomaly score is: "} + std::to_string(result); + std::string anomalyThreshold = + std::string("Anomaly threshold is: ") + std::to_string(threshold); std::string anomalyResult; if (result > threshold) { @@ -187,8 +185,7 @@ namespace app { } hal_lcd_display_text( - anomalyScore.c_str(), anomalyScore.size(), - dataPsnTxtStartX1, rowIdx1, false); + anomalyScore.c_str(), anomalyScore.size(), dataPsnTxtStartX1, rowIdx1, false); info("%s\n", anomalyScore.c_str()); info("%s\n", anomalyThreshold.c_str()); @@ -200,26 +197,28 @@ namespace app { static int8_t OutputIndexFromFileName(std::string wavFileName) { /* Filename is assumed in the form machine_id_00.wav */ - std::string delimiter = "_"; /* First character used to split the file name up. */ + std::string delimiter = "_"; /* First character used to split the file name up. */ size_t delimiterStart; std::string subString; - size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */ + size_t machineIdxInString = + 3; /* Which part of the file name the machine id should be at. */ for (size_t i = 0; i < machineIdxInString; ++i) { delimiterStart = wavFileName.find(delimiter); - subString = wavFileName.substr(0, delimiterStart); + subString = wavFileName.substr(0, delimiterStart); wavFileName.erase(0, delimiterStart + delimiter.length()); } /* At this point substring should be 00.wav */ - delimiter = "."; /* Second character used to split the file name up. */ + delimiter = "."; /* Second character used to split the file name up. */ delimiterStart = subString.find(delimiter); - subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; + subString = + (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; - auto is_number = [](const std::string& str) -> bool - { + auto is_number = [](const std::string& str) -> bool { std::string::const_iterator it = str.begin(); - while (it != str.end() && std::isdigit(*it)) ++it; + while (it != str.end() && std::isdigit(*it)) + ++it; return !str.empty() && it == str.end(); }; |