summaryrefslogtreecommitdiff
path: root/source/use_case/object_detection/src/UseCaseHandler.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/object_detection/src/UseCaseHandler.cc')
-rw-r--r--source/use_case/object_detection/src/UseCaseHandler.cc74
1 files changed, 36 insertions, 38 deletions
diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc
index f3b317e..332d199 100644
--- a/source/use_case/object_detection/src/UseCaseHandler.cc
+++ b/source/use_case/object_detection/src/UseCaseHandler.cc
@@ -19,6 +19,7 @@
#include "YoloFastestModel.hpp"
#include "UseCaseCommonUtils.hpp"
#include "DetectorPostProcessing.hpp"
+#include "DetectorPreProcessing.hpp"
#include "hal.h"
#include "log_macros.h"
@@ -33,7 +34,7 @@ namespace app {
* @param[in] results Vector of detection results to be displayed.
* @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results);
+ static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results);
/**
* @brief Draw boxes directly on the LCD for all detected objects.
@@ -43,12 +44,12 @@ namespace app {
* @param[in] imgDownscaleFactor How much image has been downscaled on LCD.
**/
static void DrawDetectionBoxes(
- const std::vector<arm::app::object_detection::DetectionResult>& results,
+ const std::vector<object_detection::DetectionResult>& results,
uint32_t imgStartX,
uint32_t imgStartY,
uint32_t imgDownscaleFactor);
- /* Object detection classification handler. */
+ /* Object detection inference handler. */
bool ObjectDetectionHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
{
auto& profiler = ctx.Get<Profiler&>("profiler");
@@ -75,9 +76,11 @@ namespace app {
return false;
}
- auto curImIdx = ctx.Get<uint32_t>("imgIndex");
+ auto initialImgIdx = ctx.Get<uint32_t>("imgIndex");
TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor0 = model.GetOutputTensor(0);
+ TfLiteTensor* outputTensor1 = model.GetOutputTensor(1);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
@@ -89,71 +92,66 @@ namespace app {
TfLiteIntArray* inputShape = model.GetInputShape(0);
- const uint32_t nCols = inputShape->data[arm::app::YoloFastestModel::ms_inputColsIdx];
- const uint32_t nRows = inputShape->data[arm::app::YoloFastestModel::ms_inputRowsIdx];
+ const int inputImgCols = inputShape->data[YoloFastestModel::ms_inputColsIdx];
+ const int inputImgRows = inputShape->data[YoloFastestModel::ms_inputRowsIdx];
- /* Get pre/post-processing objects. */
- auto& postp = ctx.Get<object_detection::DetectorPostprocessing&>("postprocess");
+ /* Set up pre and post-processing. */
+ DetectorPreProcess preProcess = DetectorPreProcess(inputTensor, true, model.IsDataSigned());
+ std::vector<object_detection::DetectionResult> results;
+ DetectorPostProcess postProcess = DetectorPostProcess(outputTensor0, outputTensor1,
+ results, inputImgRows, inputImgCols);
do {
/* Strings for presentation/logging. */
std::string str_inf{"Running inference... "};
- const uint8_t* curr_image = get_img_array(ctx.Get<uint32_t>("imgIndex"));
+ const uint8_t* currImage = get_img_array(ctx.Get<uint32_t>("imgIndex"));
- /* Copy over the data and convert to grayscale */
- auto* dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8);
+ auto dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8);
const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
inputTensor->bytes : IMAGE_DATA_SIZE;
- /* Convert to gray scale and populate input tensor. */
- image::RgbToGrayscale(curr_image, dstPtr, copySz);
+ /* Run the pre-processing, inference and post-processing. */
+ if (!preProcess.DoPreProcess(currImage, copySz)) {
+ printf_err("Pre-processing failed.");
+ return false;
+ }
/* Display image on the LCD. */
hal_lcd_display_image(
- (channelsImageDisplayed == 3) ? curr_image : dstPtr,
- nCols, nRows, channelsImageDisplayed,
+ (channelsImageDisplayed == 3) ? currImage : dstPtr,
+ inputImgCols, inputImgRows, channelsImageDisplayed,
dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
- /* If the data is signed. */
- if (model.IsDataSigned()) {
- image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
- }
-
/* Display message on the LCD - inference running. */
hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
/* Run inference over this image. */
info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
get_filename(ctx.Get<uint32_t>("imgIndex")));
if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
+ return false;
+ }
+
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
-
- /* Detector post-processing*/
- std::vector<object_detection::DetectionResult> results;
- TfLiteTensor* modelOutput0 = model.GetOutputTensor(0);
- TfLiteTensor* modelOutput1 = model.GetOutputTensor(1);
- postp.RunPostProcessing(
- nRows,
- nCols,
- modelOutput0,
- modelOutput1,
- results);
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
/* Draw boxes. */
DrawDetectionBoxes(results, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(modelOutput0);
- arm::app::DumpTensor(modelOutput1);
+ DumpTensor(modelOutput0);
+ DumpTensor(modelOutput1);
#endif /* VERIFY_TEST_OUTPUT */
if (!PresentInferenceResult(results)) {
@@ -164,12 +162,12 @@ namespace app {
IncrementAppCtxIfmIdx(ctx,"imgIndex");
- } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
+ } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImgIdx);
return true;
}
- static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results)
+ static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results)
{
hal_lcd_set_text_color(COLOR_GREEN);
@@ -186,7 +184,7 @@ namespace app {
return true;
}
- static void DrawDetectionBoxes(const std::vector<arm::app::object_detection::DetectionResult>& results,
+ static void DrawDetectionBoxes(const std::vector<object_detection::DetectionResult>& results,
uint32_t imgStartX,
uint32_t imgStartY,
uint32_t imgDownscaleFactor)