summaryrefslogtreecommitdiff
path: root/source/use_case/ad/src/UseCaseHandler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/ad/src/UseCaseHandler.cc')
-rw-r--r--source/use_case/ad/src/UseCaseHandler.cc89
1 files changed, 44 insertions, 45 deletions
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc
index 222aaf3..c71fdeb 100644
--- a/source/use_case/ad/src/UseCaseHandler.cc
+++ b/source/use_case/ad/src/UseCaseHandler.cc
@@ -1,6 +1,6 @@
/*
- * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
+ * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,16 +16,16 @@
*/
#include "UseCaseHandler.hpp"
-#include "AdModel.hpp"
-#include "InputFiles.hpp"
-#include "Classifier.hpp"
-#include "hal.h"
#include "AdMelSpectrogram.hpp"
+#include "AdModel.hpp"
+#include "AdProcessing.hpp"
#include "AudioUtils.hpp"
+#include "Classifier.hpp"
#include "ImageUtils.hpp"
+#include "InputFiles.hpp"
#include "UseCaseCommonUtils.hpp"
+#include "hal.h"
#include "log_macros.h"
-#include "AdProcessing.hpp"
namespace arm {
namespace app {
@@ -44,7 +44,7 @@ namespace app {
* File name should be in format anything_goes_XX_here.wav
* where XX is the machine ID e.g. 00, 02, 04 or 06
* @return AD model output index as 8 bit integer.
- **/
+ **/
static int8_t OutputIndexFromFileName(std::string wavFileName);
/* Anomaly Detection inference handler */
@@ -57,7 +57,7 @@ namespace app {
/* If the request has a valid size, set the audio index */
if (clipIndex < NUMBER_OF_FILES) {
- if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
+ if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) {
return false;
}
}
@@ -66,26 +66,22 @@ namespace app {
return false;
}
- auto& profiler = ctx.Get<Profiler&>("profiler");
+ auto& profiler = ctx.Get<Profiler&>("profiler");
const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength");
const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride");
- const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
- const auto trainingMean = ctx.Get<float>("trainingMean");
- auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
+ const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
+ const auto trainingMean = ctx.Get<float>("trainingMean");
+ auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
TfLiteTensor* outputTensor = model.GetOutputTensor(0);
- TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* inputTensor = model.GetInputTensor(0);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return false;
}
- AdPreProcess preProcess{
- inputTensor,
- melSpecFrameLength,
- melSpecFrameStride,
- trainingMean};
+ AdPreProcess preProcess{inputTensor, melSpecFrameLength, melSpecFrameStride, trainingMean};
AdPostProcess postProcess{outputTensor};
@@ -95,17 +91,17 @@ namespace app {
auto currentIndex = ctx.Get<uint32_t>("clipIndex");
/* Get the output index to look at based on id in the filename. */
- int8_t machineOutputIndex = OutputIndexFromFileName(get_filename(currentIndex));
+ int8_t machineOutputIndex = OutputIndexFromFileName(GetFilename(currentIndex));
if (machineOutputIndex == -1) {
return false;
}
/* Creating a sliding window through the whole audio clip. */
- auto audioDataSlider = audio::SlidingWindow<const int16_t>(
- get_audio_array(currentIndex),
- get_audio_array_size(currentIndex),
- preProcess.GetAudioWindowSize(),
- preProcess.GetAudioDataStride());
+ auto audioDataSlider =
+ audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex),
+ GetAudioArraySize(currentIndex),
+ preProcess.GetAudioWindowSize(),
+ preProcess.GetAudioDataStride());
/* Result is an averaged sum over inferences. */
float result = 0;
@@ -113,11 +109,11 @@ namespace app {
/* Display message on the LCD - inference running. */
std::string str_inf{"Running inference... "};
hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
info("Running inference on audio clip %" PRIu32 " => %s\n",
- currentIndex, get_filename(currentIndex));
+ currentIndex,
+ GetFilename(currentIndex));
/* Start sliding through audio clip. */
while (audioDataSlider.HasNext()) {
@@ -126,7 +122,8 @@ namespace app {
preProcess.SetAudioWindowIndex(audioDataSlider.Index());
preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
- info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
+ info("Inference %zu/%zu\n",
+ audioDataSlider.Index() + 1,
audioDataSlider.TotalStrides() + 1);
/* Run inference over this audio clip sliding window */
@@ -139,7 +136,7 @@ namespace app {
#if VERIFY_TEST_OUTPUT
DumpTensor(outputTensor);
-#endif /* VERIFY_TEST_OUTPUT */
+#endif /* VERIFY_TEST_OUTPUT */
} /* while (audioDataSlider.HasNext()) */
/* Use average over whole clip as final score. */
@@ -148,8 +145,7 @@ namespace app {
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
hal_lcd_display_text(
- str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
+ str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
ctx.Set<float>("result", result);
if (!PresentInferenceResult(result, scoreThreshold)) {
@@ -158,7 +154,7 @@ namespace app {
profiler.PrintProfilingResult();
- IncrementAppCtxIfmIdx(ctx,"clipIndex");
+ IncrementAppCtxIfmIdx(ctx, "clipIndex");
} while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
@@ -176,8 +172,10 @@ namespace app {
/* Display each result */
uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
- std::string anomalyScore = std::string{"Average anomaly score is: "} + std::to_string(result);
- std::string anomalyThreshold = std::string("Anomaly threshold is: ") + std::to_string(threshold);
+ std::string anomalyScore =
+ std::string{"Average anomaly score is: "} + std::to_string(result);
+ std::string anomalyThreshold =
+ std::string("Anomaly threshold is: ") + std::to_string(threshold);
std::string anomalyResult;
if (result > threshold) {
@@ -187,8 +185,7 @@ namespace app {
}
hal_lcd_display_text(
- anomalyScore.c_str(), anomalyScore.size(),
- dataPsnTxtStartX1, rowIdx1, false);
+ anomalyScore.c_str(), anomalyScore.size(), dataPsnTxtStartX1, rowIdx1, false);
info("%s\n", anomalyScore.c_str());
info("%s\n", anomalyThreshold.c_str());
@@ -200,26 +197,28 @@ namespace app {
static int8_t OutputIndexFromFileName(std::string wavFileName)
{
/* Filename is assumed in the form machine_id_00.wav */
- std::string delimiter = "_"; /* First character used to split the file name up. */
+ std::string delimiter = "_"; /* First character used to split the file name up. */
size_t delimiterStart;
std::string subString;
- size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */
+ size_t machineIdxInString =
+ 3; /* Which part of the file name the machine id should be at. */
for (size_t i = 0; i < machineIdxInString; ++i) {
delimiterStart = wavFileName.find(delimiter);
- subString = wavFileName.substr(0, delimiterStart);
+ subString = wavFileName.substr(0, delimiterStart);
wavFileName.erase(0, delimiterStart + delimiter.length());
}
/* At this point substring should be 00.wav */
- delimiter = "."; /* Second character used to split the file name up. */
+ delimiter = "."; /* Second character used to split the file name up. */
delimiterStart = subString.find(delimiter);
- subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
+ subString =
+ (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
- auto is_number = [](const std::string& str) -> bool
- {
+ auto is_number = [](const std::string& str) -> bool {
std::string::const_iterator it = str.begin();
- while (it != str.end() && std::isdigit(*it)) ++it;
+ while (it != str.end() && std::isdigit(*it))
+ ++it;
return !str.empty() && it == str.end();
};