summaryrefslogtreecommitdiff
path: root/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp')
-rw-r--r--source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp47
1 files changed, 22 insertions, 25 deletions
diff --git a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp
index 6a53688..b66edbf 100644
--- a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp
+++ b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp
@@ -28,6 +28,18 @@ namespace arm {
namespace app {
namespace object_detection {
+ struct PostProcessParams {
+ int inputImgRows{};
+ int inputImgCols{};
+ int originalImageSize{};
+ const float* anchor1;
+ const float* anchor2;
+ float threshold = 0.5f;
+ float nms = 0.45f;
+ int numClasses = 1;
+ int topN = 0;
+ };
+
struct Branch {
int resolution;
int numBox;
@@ -57,25 +69,15 @@ namespace object_detection {
public:
/**
* @brief Constructor.
- * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
- * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
- * @param[out] results Vector of detected results.
- * @param[in] inputImgRows Number of rows in the input image.
- * @param[in] inputImgCols Number of columns in the input image.
- * @param[in] threshold Post-processing threshold.
- * @param[in] nms Non-maximum Suppression threshold.
- * @param[in] numClasses Number of classes.
- * @param[in] topN Top N for each class.
+ * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
+ * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
+ * @param[out] results Vector of detected results.
+ * @param[in] postProcessParams Struct of various parameters used in post-processing.
**/
explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
TfLiteTensor* outputTensor1,
std::vector<object_detection::DetectionResult>& results,
- int inputImgRows,
- int inputImgCols,
- float threshold = 0.5f,
- float nms = 0.45f,
- int numClasses = 1,
- int topN = 0);
+ const object_detection::PostProcessParams& postProcessParams);
/**
* @brief Should perform YOLO post-processing of the result of inference then
@@ -85,16 +87,11 @@ namespace object_detection {
bool DoPostProcess() override;
private:
- TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
- TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
- std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
- int m_inputImgRows; /* Number of rows for model input. */
- int m_inputImgCols; /* Number of cols for model input. */
- float m_threshold; /* Post-processing threshold. */
- float m_nms; /* NMS threshold. */
- int m_numClasses; /* Number of classes. */
- int m_topN; /* TopN. */
- object_detection::Network m_net; /* YOLO network object. */
+ TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
+ TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
+ std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
+ const object_detection::PostProcessParams& m_postProcessParams; /* Post processing param struct. */
+ object_detection::Network m_net; /* YOLO network object. */
/**
* @brief Insert the given Detection in the list.