summaryrefslogtreecommitdiff
path: root/source/use_case/vww/src/UseCaseHandler.cc
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
commitb40ecf8522052809d2351677a96195d69e4d0c16 (patch)
tree8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/vww/src/UseCaseHandler.cc
parentc291144b7f08c21d08cdaf79cc64dc420ca70070 (diff)
downloadml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments Removed the inference runner code as it wasn't used Fixes to doc strings Consistent naming e.g. Asr/Kws instead of ASR/KWS Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'source/use_case/vww/src/UseCaseHandler.cc')
-rw-r--r--source/use_case/vww/src/UseCaseHandler.cc22
1 files changed, 10 insertions, 12 deletions
diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc
index 7681f89..267e6c4 100644
--- a/source/use_case/vww/src/UseCaseHandler.cc
+++ b/source/use_case/vww/src/UseCaseHandler.cc
@@ -53,7 +53,7 @@ namespace app {
}
TfLiteTensor* inputTensor = model.GetInputTensor(0);
-
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return false;
@@ -75,15 +75,13 @@ namespace app {
const uint32_t displayChannels = 3;
/* Set up pre and post-processing. */
- VisualWakeWordPreProcess preprocess = VisualWakeWordPreProcess(&model);
+ VisualWakeWordPreProcess preProcess = VisualWakeWordPreProcess(inputTensor);
std::vector<ClassificationResult> results;
- VisualWakeWordPostProcess postprocess = VisualWakeWordPostProcess(
- ctx.Get<Classifier&>("classifier"), &model,
+ VisualWakeWordPostProcess postProcess = VisualWakeWordPostProcess(outputTensor,
+ ctx.Get<Classifier&>("classifier"),
ctx.Get<std::vector<std::string>&>("labels"), results);
- UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
-
do {
hal_lcd_clear(COLOR_BLACK);
@@ -115,17 +113,18 @@ namespace app {
inputTensor->bytes : IMAGE_DATA_SIZE;
/* Run the pre-processing, inference and post-processing. */
- if (!runner.PreProcess(imgSrc, imgSz)) {
+ if (!preProcess.DoPreProcess(imgSrc, imgSz)) {
+ printf_err("Pre-processing failed.");
return false;
}
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -138,7 +137,6 @@ namespace app {
ctx.Set<std::vector<ClassificationResult>>("results", results);
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */