summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-04-27 17:24:36 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-27 17:24:36 +0100
commitef90497eccf48bd725a96eb79e062ebfd4e2d618 (patch)
tree055ac754d4f057baead14f723ac759cf0136a3c7
parentb40ecf8522052809d2351677a96195d69e4d0c16 (diff)
downloadml-embedded-evaluation-kit-ef90497eccf48bd725a96eb79e062ebfd4e2d618.tar.gz
MLECO-3076: Add use case API for object detection
* Removed unused prototype for box drawing Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: I1b03b88e710a5efb1ff8e107859d2245b1fead26
-rw-r--r--source/use_case/object_detection/include/DetectorPostProcessing.hpp88
-rw-r--r--source/use_case/object_detection/include/DetectorPreProcessing.hpp60
-rw-r--r--source/use_case/object_detection/src/DetectorPostProcessing.cc123
-rw-r--r--source/use_case/object_detection/src/DetectorPreProcessing.cc52
-rw-r--r--source/use_case/object_detection/src/MainLoop.cc4
-rw-r--r--source/use_case/object_detection/src/UseCaseHandler.cc74
-rw-r--r--tests/use_case/object_detection/InferenceTestYoloFastest.cc9
-rw-r--r--tests/use_case/object_detection/ObjectDetectionUCTest.cc4
8 files changed, 262 insertions, 152 deletions
diff --git a/source/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
index cdb14f5..b3ddb2c 100644
--- a/source/use_case/object_detection/include/DetectorPostProcessing.hpp
+++ b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
@@ -21,11 +21,13 @@
#include "ImageUtils.hpp"
#include "DetectionResult.hpp"
#include "YoloFastestModel.hpp"
+#include "BaseProcessing.hpp"
#include <forward_list>
namespace arm {
namespace app {
+
namespace object_detection {
struct Branch {
@@ -46,42 +48,55 @@ namespace object_detection {
int topN;
};
+} /* namespace object_detection */
+
/**
- * @brief Helper class to manage tensor post-processing for "object_detection"
- * output.
+ * @brief Post-processing class for Object Detection use case.
+ * Implements methods declared by BasePostProcess and anything else needed
+ * to populate result vector.
*/
- class DetectorPostprocessing {
+ class DetectorPostProcess : public BasePostProcess {
public:
/**
- * @brief Constructor.
- * @param[in] threshold Post-processing threshold.
- * @param[in] nms Non-maximum Suppression threshold.
- * @param[in] numClasses Number of classes.
- * @param[in] topN Top N for each class.
+ * @brief Constructor.
+ * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
+ * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
+ * @param[out] results Vector of detected results.
+ * @param[in] inputImgRows Number of rows in the input image.
+ * @param[in] inputImgCols Number of columns in the input image.
+ * @param[in] threshold Post-processing threshold.
+ * @param[in] nms Non-maximum Suppression threshold.
+ * @param[in] numClasses Number of classes.
+ * @param[in] topN Top N for each class.
**/
- explicit DetectorPostprocessing(float threshold = 0.5f,
- float nms = 0.45f,
- int numClasses = 1,
- int topN = 0);
+ explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
+ TfLiteTensor* outputTensor1,
+ std::vector<object_detection::DetectionResult>& results,
+ int inputImgRows,
+ int inputImgCols,
+ float threshold = 0.5f,
+ float nms = 0.45f,
+ int numClasses = 1,
+ int topN = 0);
/**
- * @brief Post processing part of YOLO object detection CNN.
- * @param[in] imgRows Number of rows in the input image.
- * @param[in] imgCols Number of columns in the input image.
- * @param[in] modelOutput Output tensors after CNN invoked.
- * @param[out] resultsOut Vector of detected results.
+ * @brief Should perform YOLO post-processing of the result of inference then
+ * populate Detection result data for any later use.
+ * @return true if successful, false otherwise.
**/
- void RunPostProcessing(uint32_t imgRows,
- uint32_t imgCols,
- TfLiteTensor* modelOutput0,
- TfLiteTensor* modelOutput1,
- std::vector<DetectionResult>& resultsOut);
+ bool DoPostProcess() override;
private:
- float m_threshold; /* Post-processing threshold */
- float m_nms; /* NMS threshold */
- int m_numClasses; /* Number of classes */
- int m_topN; /* TopN */
+ TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
+ TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
+ std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
+ int m_inputImgRows; /* Number of rows for model input. */
+ int m_inputImgCols; /* Number of cols for model input. */
+ float m_threshold; /* Post-processing threshold. */
+ float m_nms; /* NMS threshold. */
+ int m_numClasses; /* Number of classes. */
+ int m_topN; /* TopN. */
+ object_detection::Network m_net; /* YOLO network object. */
/**
* @brief Insert the given Detection in the list.
@@ -98,32 +113,13 @@ namespace object_detection {
* @param[in] threshold Detections threshold.
* @param[out] detections Detection boxes.
**/
- void GetNetworkBoxes(Network& net,
+ void GetNetworkBoxes(object_detection::Network& net,
int imageWidth,
int imageHeight,
float threshold,
std::forward_list<image::Detection>& detections);
-
- /**
- * @brief Draw on the given image a bounding box starting at (boxX, boxY).
- * @param[in/out] imgIn Image.
- * @param[in] imWidth Image width.
- * @param[in] imHeight Image height.
- * @param[in] boxX Axis X starting point.
- * @param[in] boxY Axis Y starting point.
- * @param[in] boxWidth Box width.
- * @param[in] boxHeight Box height.
- **/
- void DrawBoxOnImage(uint8_t* imgIn,
- int imWidth,
- int imHeight,
- int boxX,
- int boxY,
- int boxWidth,
- int boxHeight);
};
-} /* namespace object_detection */
} /* namespace app */
} /* namespace arm */
diff --git a/source/use_case/object_detection/include/DetectorPreProcessing.hpp b/source/use_case/object_detection/include/DetectorPreProcessing.hpp
new file mode 100644
index 0000000..4936048
--- /dev/null
+++ b/source/use_case/object_detection/include/DetectorPreProcessing.hpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DETECTOR_PRE_PROCESSING_HPP
+#define DETECTOR_PRE_PROCESSING_HPP
+
+#include "BaseProcessing.hpp"
+#include "Classifier.hpp"
+
+namespace arm {
+namespace app {
+
+ /**
+ * @brief Pre-processing class for Object detection use case.
+ * Implements methods declared by BasePreProcess and anything else needed
+ * to populate input tensors ready for inference.
+ */
+ class DetectorPreProcess : public BasePreProcess {
+
+ public:
+ /**
+ * @brief Constructor
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] rgb2Gray Convert image from 3 channel RGB to 1 channel grayscale.
+ * @param[in] convertToInt8 Convert the image from uint8 to int8 range.
+ **/
+ explicit DetectorPreProcess(TfLiteTensor* inputTensor, bool rgb2Gray, bool convertToInt8);
+
+ /**
+ * @brief Should perform pre-processing of 'raw' input image data and load it into
+ * TFLite Micro input tensor ready for inference
+ * @param[in] input Pointer to the data that pre-processing will work on.
+ * @param[in] inputSize Size of the input data.
+ * @return true if successful, false otherwise.
+ **/
+ bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ private:
+ TfLiteTensor* m_inputTensor;
+ bool m_rgb2Gray;
+ bool m_convertToInt8;
+ };
+
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* DETECTOR_PRE_PROCESSING_HPP */ \ No newline at end of file
diff --git a/source/use_case/object_detection/src/DetectorPostProcessing.cc b/source/use_case/object_detection/src/DetectorPostProcessing.cc
index a890c9e..fb1606a 100644
--- a/source/use_case/object_detection/src/DetectorPostProcessing.cc
+++ b/source/use_case/object_detection/src/DetectorPostProcessing.cc
@@ -21,64 +21,73 @@
namespace arm {
namespace app {
-namespace object_detection {
-
-DetectorPostprocessing::DetectorPostprocessing(
- const float threshold,
- const float nms,
- int numClasses,
- int topN)
- : m_threshold(threshold),
- m_nms(nms),
- m_numClasses(numClasses),
- m_topN(topN)
-{}
-
-void DetectorPostprocessing::RunPostProcessing(
- uint32_t imgRows,
- uint32_t imgCols,
- TfLiteTensor* modelOutput0,
- TfLiteTensor* modelOutput1,
- std::vector<DetectionResult>& resultsOut)
+
+ DetectorPostProcess::DetectorPostProcess(
+ TfLiteTensor* modelOutput0,
+ TfLiteTensor* modelOutput1,
+ std::vector<object_detection::DetectionResult>& results,
+ int inputImgRows,
+ int inputImgCols,
+ const float threshold,
+ const float nms,
+ int numClasses,
+ int topN)
+ : m_outputTensor0{modelOutput0},
+ m_outputTensor1{modelOutput1},
+ m_results{results},
+ m_inputImgRows{inputImgRows},
+ m_inputImgCols{inputImgCols},
+ m_threshold(threshold),
+ m_nms(nms),
+ m_numClasses(numClasses),
+ m_topN(topN)
{
- /* init postprocessing */
- Network net {
- .inputWidth = static_cast<int>(imgCols),
- .inputHeight = static_cast<int>(imgRows),
- .numClasses = m_numClasses,
+ /* Init PostProcessing */
+ this->m_net =
+ object_detection::Network {
+ .inputWidth = inputImgCols,
+ .inputHeight = inputImgRows,
+ .numClasses = numClasses,
.branches = {
- Branch {
- .resolution = static_cast<int>(imgCols/32),
- .numBox = 3,
- .anchor = anchor1,
- .modelOutput = modelOutput0->data.int8,
- .scale = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->scale->data[0],
- .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->zero_point->data[0],
- .size = modelOutput0->bytes
+ object_detection::Branch {
+ .resolution = inputImgCols/32,
+ .numBox = 3,
+ .anchor = anchor1,
+ .modelOutput = this->m_outputTensor0->data.int8,
+ .scale = (static_cast<TfLiteAffineQuantization*>(
+ this->m_outputTensor0->quantization.params))->scale->data[0],
+ .zeroPoint = (static_cast<TfLiteAffineQuantization*>(
+ this->m_outputTensor0->quantization.params))->zero_point->data[0],
+ .size = this->m_outputTensor0->bytes
},
- Branch {
- .resolution = static_cast<int>(imgCols/16),
- .numBox = 3,
- .anchor = anchor2,
- .modelOutput = modelOutput1->data.int8,
- .scale = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->scale->data[0],
- .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->zero_point->data[0],
- .size = modelOutput1->bytes
+ object_detection::Branch {
+ .resolution = inputImgCols/16,
+ .numBox = 3,
+ .anchor = anchor2,
+ .modelOutput = this->m_outputTensor1->data.int8,
+ .scale = (static_cast<TfLiteAffineQuantization*>(
+ this->m_outputTensor1->quantization.params))->scale->data[0],
+ .zeroPoint = (static_cast<TfLiteAffineQuantization*>(
+ this->m_outputTensor1->quantization.params))->zero_point->data[0],
+ .size = this->m_outputTensor1->bytes
}
},
.topN = m_topN
};
/* End init */
+}
+bool DetectorPostProcess::DoPostProcess()
+{
/* Start postprocessing */
int originalImageWidth = originalImageSize;
int originalImageHeight = originalImageSize;
std::forward_list<image::Detection> detections;
- GetNetworkBoxes(net, originalImageWidth, originalImageHeight, m_threshold, detections);
+ GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections);
/* Do nms */
- CalculateNMS(detections, net.numClasses, m_nms);
+ CalculateNMS(detections, this->m_net.numClasses, m_nms);
for (auto& it: detections) {
float xMin = it.bbox.x - it.bbox.w / 2.0f;
@@ -104,24 +113,24 @@ void DetectorPostprocessing::RunPostProcessing(
float boxWidth = xMax - xMin;
float boxHeight = yMax - yMin;
- for (int j = 0; j < net.numClasses; ++j) {
+ for (int j = 0; j < this->m_net.numClasses; ++j) {
if (it.prob[j] > 0) {
- DetectionResult tmpResult = {};
+ object_detection::DetectionResult tmpResult = {};
tmpResult.m_normalisedVal = it.prob[j];
tmpResult.m_x0 = boxX;
tmpResult.m_y0 = boxY;
tmpResult.m_w = boxWidth;
tmpResult.m_h = boxHeight;
- resultsOut.push_back(tmpResult);
+ this->m_results.push_back(tmpResult);
}
}
}
+ return true;
}
-
-void DetectorPostprocessing::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det)
+void DetectorPostProcess::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det)
{
std::forward_list<image::Detection>::iterator it;
std::forward_list<image::Detection>::iterator last_it;
@@ -136,7 +145,12 @@ void DetectorPostprocessing::InsertTopNDetections(std::forward_list<image::Detec
}
}
-void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int imageHeight, float threshold, std::forward_list<image::Detection>& detections)
+void DetectorPostProcess::GetNetworkBoxes(
+ object_detection::Network& net,
+ int imageWidth,
+ int imageHeight,
+ float threshold,
+ std::forward_list<image::Detection>& detections)
{
int numClasses = net.numClasses;
int num = 0;
@@ -169,10 +183,14 @@ void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int i
int bbox_h_offset = bbox_x_offset + 3;
int bbox_scores_offset = bbox_x_offset + 5;
- det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
- det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
- det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
- det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
+ det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset])
+ - net.branches[i].zeroPoint) * net.branches[i].scale;
+ det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset])
+ - net.branches[i].zeroPoint) * net.branches[i].scale;
+ det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset])
+ - net.branches[i].zeroPoint) * net.branches[i].scale;
+ det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset])
+ - net.branches[i].zeroPoint) * net.branches[i].scale;
float bbox_x, bbox_y;
@@ -218,6 +236,5 @@ void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int i
num -=1;
}
-} /* namespace object_detection */
} /* namespace app */
} /* namespace arm */
diff --git a/source/use_case/object_detection/src/DetectorPreProcessing.cc b/source/use_case/object_detection/src/DetectorPreProcessing.cc
new file mode 100644
index 0000000..7212046
--- /dev/null
+++ b/source/use_case/object_detection/src/DetectorPreProcessing.cc
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "DetectorPreProcessing.hpp"
+#include "ImageUtils.hpp"
+#include "log_macros.h"
+
+namespace arm {
+namespace app {
+
+ DetectorPreProcess::DetectorPreProcess(TfLiteTensor* inputTensor, bool rgb2Gray, bool convertToInt8)
+ : m_inputTensor{inputTensor},
+ m_rgb2Gray{rgb2Gray},
+ m_convertToInt8{convertToInt8}
+ {}
+
+ bool DetectorPreProcess::DoPreProcess(const void* data, size_t inputSize) {
+ if (data == nullptr) {
+ printf_err("Data pointer is null");
+ }
+
+ auto input = static_cast<const uint8_t*>(data);
+
+ if (this->m_rgb2Gray) {
+ image::RgbToGrayscale(input, this->m_inputTensor->data.uint8, this->m_inputTensor->bytes);
+ } else {
+ std::memcpy(this->m_inputTensor->data.data, input, inputSize);
+ }
+ debug("Input tensor populated \n");
+
+ if (this->m_convertToInt8) {
+ image::ConvertImgToInt8(this->m_inputTensor->data.data, this->m_inputTensor->bytes);
+ }
+
+ return true;
+ }
+
+} /* namespace app */
+} /* namespace arm */ \ No newline at end of file
diff --git a/source/use_case/object_detection/src/MainLoop.cc b/source/use_case/object_detection/src/MainLoop.cc
index acfc195..4291164 100644
--- a/source/use_case/object_detection/src/MainLoop.cc
+++ b/source/use_case/object_detection/src/MainLoop.cc
@@ -19,7 +19,6 @@
#include "YoloFastestModel.hpp" /* Model class for running inference. */
#include "UseCaseHandler.hpp" /* Handlers for different user options. */
#include "UseCaseCommonUtils.hpp" /* Utils functions. */
-#include "DetectorPostProcessing.hpp" /* Post-processing class. */
#include "log_macros.h"
static void DisplayDetectionMenu()
@@ -53,9 +52,6 @@ void main_loop()
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("imgIndex", 0);
- arm::app::object_detection::DetectorPostprocessing postp;
- caseContext.Set<arm::app::object_detection::DetectorPostprocessing&>("postprocess", postp);
-
/* Loop. */
bool executionSuccessful = true;
diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc
index f3b317e..332d199 100644
--- a/source/use_case/object_detection/src/UseCaseHandler.cc
+++ b/source/use_case/object_detection/src/UseCaseHandler.cc
@@ -19,6 +19,7 @@
#include "YoloFastestModel.hpp"
#include "UseCaseCommonUtils.hpp"
#include "DetectorPostProcessing.hpp"
+#include "DetectorPreProcessing.hpp"
#include "hal.h"
#include "log_macros.h"
@@ -33,7 +34,7 @@ namespace app {
* @param[in] results Vector of detection results to be displayed.
* @return true if successful, false otherwise.
**/
- static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results);
+ static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results);
/**
* @brief Draw boxes directly on the LCD for all detected objects.
@@ -43,12 +44,12 @@ namespace app {
* @param[in] imgDownscaleFactor How much image has been downscaled on LCD.
**/
static void DrawDetectionBoxes(
- const std::vector<arm::app::object_detection::DetectionResult>& results,
+ const std::vector<object_detection::DetectionResult>& results,
uint32_t imgStartX,
uint32_t imgStartY,
uint32_t imgDownscaleFactor);
- /* Object detection classification handler. */
+ /* Object detection inference handler. */
bool ObjectDetectionHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll)
{
auto& profiler = ctx.Get<Profiler&>("profiler");
@@ -75,9 +76,11 @@ namespace app {
return false;
}
- auto curImIdx = ctx.Get<uint32_t>("imgIndex");
+ auto initialImgIdx = ctx.Get<uint32_t>("imgIndex");
TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor0 = model.GetOutputTensor(0);
+ TfLiteTensor* outputTensor1 = model.GetOutputTensor(1);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
@@ -89,71 +92,66 @@ namespace app {
TfLiteIntArray* inputShape = model.GetInputShape(0);
- const uint32_t nCols = inputShape->data[arm::app::YoloFastestModel::ms_inputColsIdx];
- const uint32_t nRows = inputShape->data[arm::app::YoloFastestModel::ms_inputRowsIdx];
+ const int inputImgCols = inputShape->data[YoloFastestModel::ms_inputColsIdx];
+ const int inputImgRows = inputShape->data[YoloFastestModel::ms_inputRowsIdx];
- /* Get pre/post-processing objects. */
- auto& postp = ctx.Get<object_detection::DetectorPostprocessing&>("postprocess");
+ /* Set up pre and post-processing. */
+ DetectorPreProcess preProcess = DetectorPreProcess(inputTensor, true, model.IsDataSigned());
+ std::vector<object_detection::DetectionResult> results;
+ DetectorPostProcess postProcess = DetectorPostProcess(outputTensor0, outputTensor1,
+ results, inputImgRows, inputImgCols);
do {
/* Strings for presentation/logging. */
std::string str_inf{"Running inference... "};
- const uint8_t* curr_image = get_img_array(ctx.Get<uint32_t>("imgIndex"));
+ const uint8_t* currImage = get_img_array(ctx.Get<uint32_t>("imgIndex"));
- /* Copy over the data and convert to grayscale */
- auto* dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8);
+ auto dstPtr = static_cast<uint8_t*>(inputTensor->data.uint8);
const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ?
inputTensor->bytes : IMAGE_DATA_SIZE;
- /* Convert to gray scale and populate input tensor. */
- image::RgbToGrayscale(curr_image, dstPtr, copySz);
+ /* Run the pre-processing, inference and post-processing. */
+ if (!preProcess.DoPreProcess(currImage, copySz)) {
+ printf_err("Pre-processing failed.");
+ return false;
+ }
/* Display image on the LCD. */
hal_lcd_display_image(
- (channelsImageDisplayed == 3) ? curr_image : dstPtr,
- nCols, nRows, channelsImageDisplayed,
+ (channelsImageDisplayed == 3) ? currImage : dstPtr,
+ inputImgCols, inputImgRows, channelsImageDisplayed,
dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
- /* If the data is signed. */
- if (model.IsDataSigned()) {
- image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
- }
-
/* Display message on the LCD - inference running. */
hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
/* Run inference over this image. */
info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
get_filename(ctx.Get<uint32_t>("imgIndex")));
if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
+ return false;
+ }
+
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
-
- /* Detector post-processing*/
- std::vector<object_detection::DetectionResult> results;
- TfLiteTensor* modelOutput0 = model.GetOutputTensor(0);
- TfLiteTensor* modelOutput1 = model.GetOutputTensor(1);
- postp.RunPostProcessing(
- nRows,
- nCols,
- modelOutput0,
- modelOutput1,
- results);
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
/* Draw boxes. */
DrawDetectionBoxes(results, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor);
#if VERIFY_TEST_OUTPUT
- arm::app::DumpTensor(modelOutput0);
- arm::app::DumpTensor(modelOutput1);
+ DumpTensor(modelOutput0);
+ DumpTensor(modelOutput1);
#endif /* VERIFY_TEST_OUTPUT */
if (!PresentInferenceResult(results)) {
@@ -164,12 +162,12 @@ namespace app {
IncrementAppCtxIfmIdx(ctx,"imgIndex");
- } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx);
+ } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImgIdx);
return true;
}
- static bool PresentInferenceResult(const std::vector<arm::app::object_detection::DetectionResult>& results)
+ static bool PresentInferenceResult(const std::vector<object_detection::DetectionResult>& results)
{
hal_lcd_set_text_color(COLOR_GREEN);
@@ -186,7 +184,7 @@ namespace app {
return true;
}
- static void DrawDetectionBoxes(const std::vector<arm::app::object_detection::DetectionResult>& results,
+ static void DrawDetectionBoxes(const std::vector<object_detection::DetectionResult>& results,
uint32_t imgStartX,
uint32_t imgStartY,
uint32_t imgDownscaleFactor)
diff --git a/tests/use_case/object_detection/InferenceTestYoloFastest.cc b/tests/use_case/object_detection/InferenceTestYoloFastest.cc
index 8ef012d..2c035e7 100644
--- a/tests/use_case/object_detection/InferenceTestYoloFastest.cc
+++ b/tests/use_case/object_detection/InferenceTestYoloFastest.cc
@@ -94,13 +94,8 @@ void TestInferenceDetectionResults(int imageIdx, arm::app::Model& model, T toler
REQUIRE(tflite::GetTensorData<T>(output_arr[i]));
}
- arm::app::object_detection::DetectorPostprocessing postp;
- postp.RunPostProcessing(
- nRows,
- nCols,
- output_arr[0],
- output_arr[1],
- results);
+ arm::app::DetectorPostProcess postp{output_arr[0], output_arr[1], results, nRows, nCols};
+ postp.DoPostProcess();
std::vector<std::vector<arm::app::object_detection::DetectionResult>> expected_results;
GetExpectedResults(expected_results);
diff --git a/tests/use_case/object_detection/ObjectDetectionUCTest.cc b/tests/use_case/object_detection/ObjectDetectionUCTest.cc
index a7e4f33..023b893 100644
--- a/tests/use_case/object_detection/ObjectDetectionUCTest.cc
+++ b/tests/use_case/object_detection/ObjectDetectionUCTest.cc
@@ -58,8 +58,6 @@ TEST_CASE("Inference by index")
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("imgIndex", 0);
- arm::app::object_detection::DetectorPostprocessing postp;
- caseContext.Set<arm::app::object_detection::DetectorPostprocessing&>("postprocess", postp);
REQUIRE(arm::app::ObjectDetectionHandler(caseContext, 0, false));
}
@@ -83,8 +81,6 @@ TEST_CASE("Inference run all images")
caseContext.Set<arm::app::Profiler&>("profiler", profiler);
caseContext.Set<arm::app::Model&>("model", model);
caseContext.Set<uint32_t>("imgIndex", 0);
- arm::app::object_detection::DetectorPostprocessing postp;
- caseContext.Set<arm::app::object_detection::DetectorPostprocessing&>("postprocess", postp);
REQUIRE(arm::app::ObjectDetectionHandler(caseContext, 0, true));
}