summaryrefslogtreecommitdiff
path: root/source/use_case/img_class/src/UseCaseHandler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/img_class/src/UseCaseHandler.cc')
-rw-r--r--source/use_case/img_class/src/UseCaseHandler.cc22
1 files changed, 11 insertions, 11 deletions
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc
index c68d816..5cc3959 100644
--- a/source/use_case/img_class/src/UseCaseHandler.cc
+++ b/source/use_case/img_class/src/UseCaseHandler.cc
@@ -59,6 +59,7 @@ namespace app {
}
TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return false;
@@ -74,13 +75,12 @@ namespace app {
const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
/* Set up pre and post-processing. */
- ImgClassPreProcess preprocess = ImgClassPreProcess(&model);
+ ImgClassPreProcess preProcess = ImgClassPreProcess(inputTensor, model.IsDataSigned());
std::vector<ClassificationResult> results;
- ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model,
- ctx.Get<std::vector<std::string>&>("labels"), results);
-
- UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
+ ImgClassPostProcess postProcess = ImgClassPostProcess(outputTensor,
+ ctx.Get<ImgClassClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"),
+ results);
do {
hal_lcd_clear(COLOR_BLACK);
@@ -113,17 +113,18 @@ namespace app {
inputTensor->bytes : IMAGE_DATA_SIZE;
/* Run the pre-processing, inference and post-processing. */
- if (!runner.PreProcess(imgSrc, imgSz)) {
+ if (!preProcess.DoPreProcess(imgSrc, imgSz)) {
+ printf_err("Pre-processing failed.");
return false;
}
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -136,7 +137,6 @@ namespace app {
ctx.Set<std::vector<ClassificationResult>>("results", results);
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */