diff options
Diffstat (limited to 'source/use_case/img_class/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index c68d816..5cc3959 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -59,6 +59,7 @@ namespace app { } TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -74,13 +75,12 @@ namespace app { const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; /* Set up pre and post-processing. */ - ImgClassPreProcess preprocess = ImgClassPreProcess(&model); + ImgClassPreProcess preProcess = ImgClassPreProcess(inputTensor, model.IsDataSigned()); std::vector<ClassificationResult> results; - ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model, - ctx.Get<std::vector<std::string>&>("labels"), results); - - UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); + ImgClassPostProcess postProcess = ImgClassPostProcess(outputTensor, + ctx.Get<ImgClassClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), + results); do { hal_lcd_clear(COLOR_BLACK); @@ -113,17 +113,18 @@ namespace app { inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ - if (!runner.PreProcess(imgSrc, imgSz)) { + if (!preProcess.DoPreProcess(imgSrc, imgSz)) { + printf_err("Pre-processing failed."); return false; } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } @@ -136,7 +137,6 @@ namespace app { ctx.Set<std::vector<ClassificationResult>>("results", results); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ |