diff options
Diffstat (limited to 'source/use_case/object_detection/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/object_detection/src/UseCaseHandler.cc | 74 |
1 files changed, 36 insertions, 38 deletions
diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc index f3b317e..332d199 100644 --- a/source/use_case/object_detection/src/UseCaseHandler.cc +++ b/source/use_case/object_detection/src/UseCaseHandler.cc @@ -19,6 +19,7 @@ #include "YoloFastestModel.hpp" #include "UseCaseCommonUtils.hpp" #include "DetectorPostProcessing.hpp" +#include "DetectorPreProcessing.hpp" #include "hal.h" #include "log_macros.h" @@ -33,7 +34,7 @@ namespace app { * @param[in] results Vector of detection results to be displayed. * @return true if successful, false otherwise. **/ - static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results); + static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results); /** * @brief Draw boxes directly on the LCD for all detected objects. @@ -43,12 +44,12 @@ namespace app { * @param[in] imgDownscaleFactor How much image has been downscaled on LCD. **/ static void DrawDetectionBoxes( - const std::vector<arm::app::object_detection::DetectionResult>& results, + const std::vector<object_detection::DetectionResult>& results, uint32_t imgStartX, uint32_t imgStartY, uint32_t imgDownscaleFactor); - /* Object detection classification handler. */ + /* Object detection inference handler. */ bool ObjectDetectionHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { auto& profiler = ctx.Get<Profiler&>("profiler"); @@ -75,9 +76,11 @@ namespace app { return false; } - auto curImIdx = ctx.Get<uint32_t>("imgIndex"); + auto initialImgIdx = ctx.Get<uint32_t>("imgIndex"); TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor0 = model.GetOutputTensor(0); + TfLiteTensor* outputTensor1 = model.GetOutputTensor(1); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); @@ -89,71 +92,66 @@ namespace app { TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t nCols = inputShape->data[arm::app::YoloFastestModel::ms_inputColsIdx]; - const uint32_t nRows = inputShape->data[arm::app::YoloFastestModel::ms_inputRowsIdx]; + const int inputImgCols = inputShape->data[YoloFastestModel::ms_inputColsIdx]; + const int inputImgRows = inputShape->data[YoloFastestModel::ms_inputRowsIdx]; - /* Get pre/post-processing objects. */ - auto& postp = ctx.Get<object_detection::DetectorPostprocessing&>("postprocess"); + /* Set up pre and post-processing. */ + DetectorPreProcess preProcess = DetectorPreProcess(inputTensor, true, model.IsDataSigned()); + std::vector<object_detection::DetectionResult> results; + DetectorPostProcess postProcess = DetectorPostProcess(outputTensor0, outputTensor1, + results, inputImgRows, inputImgCols); do { /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; - const uint8_t* curr_image = get_img_array(ctx.Get<uint32_t>("imgIndex")); + const uint8_t* currImage = get_img_array(ctx.Get<uint32_t>("imgIndex")); - /* Copy over the data and convert to grayscale */ - auto* dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8); + auto dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8); const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? inputTensor->bytes : IMAGE_DATA_SIZE; - /* Convert to gray scale and populate input tensor. */ - image::RgbToGrayscale(curr_image, dstPtr, copySz); + /* Run the pre-processing, inference and post-processing. */ + if (!preProcess.DoPreProcess(currImage, copySz)) { + printf_err("Pre-processing failed."); + return false; + } /* Display image on the LCD. */ hal_lcd_display_image( - (channelsImageDisplayed == 3) ? curr_image : dstPtr, - nCols, nRows, channelsImageDisplayed, + (channelsImageDisplayed == 3) ? currImage : dstPtr, + inputImgCols, inputImgRows, channelsImageDisplayed, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); - /* If the data is signed. */ - if (model.IsDataSigned()) { - image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); - } - /* Display message on the LCD - inference running. */ hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); /* Run inference over this image. */ info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"), get_filename(ctx.Get<uint32_t>("imgIndex"))); if (!RunInference(model, profiler)) { + printf_err("Inference failed."); + return false; + } + + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } /* Erase. */ str_inf = std::string(str_inf.size(), ' '); hal_lcd_display_text(str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - - /* Detector post-processing*/ - std::vector<object_detection::DetectionResult> results; - TfLiteTensor* modelOutput0 = model.GetOutputTensor(0); - TfLiteTensor* modelOutput1 = model.GetOutputTensor(1); - postp.RunPostProcessing( - nRows, - nCols, - modelOutput0, - modelOutput1, - results); + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); /* Draw boxes. */ DrawDetectionBoxes(results, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); #if VERIFY_TEST_OUTPUT - arm::app::DumpTensor(modelOutput0); - arm::app::DumpTensor(modelOutput1); + DumpTensor(modelOutput0); + DumpTensor(modelOutput1); #endif /* VERIFY_TEST_OUTPUT */ if (!PresentInferenceResult(results)) { @@ -164,12 +162,12 @@ namespace app { IncrementAppCtxIfmIdx(ctx,"imgIndex"); - } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); + } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImgIdx); return true; } - static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results) + static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results) { hal_lcd_set_text_color(COLOR_GREEN); @@ -186,7 +184,7 @@ namespace app { return true; } - static void DrawDetectionBoxes(const std::vector<arm::app::object_detection::DetectionResult>& results, + static void DrawDetectionBoxes(const std::vector<object_detection::DetectionResult>& results, uint32_t imgStartX, uint32_t imgStartY, uint32_t imgDownscaleFactor) |