summaryrefslogtreecommitdiff
path: root/source/use_case/object_detection/include/DetectorPostProcessing.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/object_detection/include/DetectorPostProcessing.hpp')
-rw-r--r--source/use_case/object_detection/include/DetectorPostProcessing.hpp88
1 files changed, 42 insertions, 46 deletions
diff --git a/source/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
index cdb14f5..b3ddb2c 100644
--- a/source/use_case/object_detection/include/DetectorPostProcessing.hpp
+++ b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
@@ -21,11 +21,13 @@
#include "ImageUtils.hpp"
#include "DetectionResult.hpp"
#include "YoloFastestModel.hpp"
+#include "BaseProcessing.hpp"
#include <forward_list>
namespace arm {
namespace app {
+
namespace object_detection {
struct Branch {
@@ -46,42 +48,55 @@ namespace object_detection {
int topN;
};
+} /* namespace object_detection */
+
/**
- * @brief Helper class to manage tensor post-processing for "object_detection"
- * output.
+ * @brief Post-processing class for Object Detection use case.
+ * Implements methods declared by BasePostProcess and anything else needed
+ * to populate result vector.
*/
- class DetectorPostprocessing {
+ class DetectorPostProcess : public BasePostProcess {
public:
/**
- * @brief Constructor.
- * @param[in] threshold Post-processing threshold.
- * @param[in] nms Non-maximum Suppression threshold.
- * @param[in] numClasses Number of classes.
- * @param[in] topN Top N for each class.
+ * @brief Constructor.
+ * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
+ * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
+ * @param[out] results Vector of detected results.
+ * @param[in] inputImgRows Number of rows in the input image.
+ * @param[in] inputImgCols Number of columns in the input image.
+ * @param[in] threshold Post-processing threshold.
+ * @param[in] nms Non-maximum Suppression threshold.
+ * @param[in] numClasses Number of classes.
+ * @param[in] topN Top N for each class.
**/
- explicit DetectorPostprocessing(float threshold = 0.5f,
- float nms = 0.45f,
- int numClasses = 1,
- int topN = 0);
+ explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
+ TfLiteTensor* outputTensor1,
+ std::vector<object_detection::DetectionResult>& results,
+ int inputImgRows,
+ int inputImgCols,
+ float threshold = 0.5f,
+ float nms = 0.45f,
+ int numClasses = 1,
+ int topN = 0);
/**
- * @brief Post processing part of YOLO object detection CNN.
- * @param[in] imgRows Number of rows in the input image.
- * @param[in] imgCols Number of columns in the input image.
- * @param[in] modelOutput Output tensors after CNN invoked.
- * @param[out] resultsOut Vector of detected results.
+ * @brief Should perform YOLO post-processing of the result of inference then
+ * populate Detection result data for any later use.
+ * @return true if successful, false otherwise.
**/
- void RunPostProcessing(uint32_t imgRows,
- uint32_t imgCols,
- TfLiteTensor* modelOutput0,
- TfLiteTensor* modelOutput1,
- std::vector<DetectionResult>& resultsOut);
+ bool DoPostProcess() override;
private:
- float m_threshold; /* Post-processing threshold */
- float m_nms; /* NMS threshold */
- int m_numClasses; /* Number of classes */
- int m_topN; /* TopN */
+ TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
+ TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
+ std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
+ int m_inputImgRows; /* Number of rows for model input. */
+ int m_inputImgCols; /* Number of cols for model input. */
+ float m_threshold; /* Post-processing threshold. */
+ float m_nms; /* NMS threshold. */
+ int m_numClasses; /* Number of classes. */
+ int m_topN; /* TopN. */
+ object_detection::Network m_net; /* YOLO network object. */
/**
* @brief Insert the given Detection in the list.
@@ -98,32 +113,13 @@ namespace object_detection {
* @param[in] threshold Detections threshold.
* @param[out] detections Detection boxes.
**/
- void GetNetworkBoxes(Network& net,
+ void GetNetworkBoxes(object_detection::Network& net,
int imageWidth,
int imageHeight,
float threshold,
std::forward_list<image::Detection>& detections);
-
- /**
- * @brief Draw on the given image a bounding box starting at (boxX, boxY).
- * @param[in/out] imgIn Image.
- * @param[in] imWidth Image width.
- * @param[in] imHeight Image height.
- * @param[in] boxX Axis X starting point.
- * @param[in] boxY Axis Y starting point.
- * @param[in] boxWidth Box width.
- * @param[in] boxHeight Box height.
- **/
- void DrawBoxOnImage(uint8_t* imgIn,
- int imWidth,
- int imHeight,
- int boxX,
- int boxY,
- int boxWidth,
- int boxHeight);
};
-} /* namespace object_detection */
} /* namespace app */
} /* namespace arm */