diff options
Diffstat (limited to 'source/application/api/use_case/object_detection/include')
-rw-r--r-- | source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp | 47 |
1 files changed, 22 insertions, 25 deletions
diff --git a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp index 6a53688..b66edbf 100644 --- a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp +++ b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp @@ -28,6 +28,18 @@ namespace arm { namespace app { namespace object_detection { + struct PostProcessParams { + int inputImgRows{}; + int inputImgCols{}; + int originalImageSize{}; + const float* anchor1; + const float* anchor2; + float threshold = 0.5f; + float nms = 0.45f; + int numClasses = 1; + int topN = 0; + }; + struct Branch { int resolution; int numBox; @@ -57,25 +69,15 @@ namespace object_detection { public: /** * @brief Constructor. - * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0. - * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1. - * @param[out] results Vector of detected results. - * @param[in] inputImgRows Number of rows in the input image. - * @param[in] inputImgCols Number of columns in the input image. - * @param[in] threshold Post-processing threshold. - * @param[in] nms Non-maximum Suppression threshold. - * @param[in] numClasses Number of classes. - * @param[in] topN Top N for each class. + * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0. + * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1. + * @param[out] results Vector of detected results. + * @param[in] postProcessParams Struct of various parameters used in post-processing. **/ explicit DetectorPostProcess(TfLiteTensor* outputTensor0, TfLiteTensor* outputTensor1, std::vector<object_detection::DetectionResult>& results, - int inputImgRows, - int inputImgCols, - float threshold = 0.5f, - float nms = 0.45f, - int numClasses = 1, - int topN = 0); + const object_detection::PostProcessParams& postProcessParams); /** * @brief Should perform YOLO post-processing of the result of inference then @@ -85,16 +87,11 @@ namespace object_detection { bool DoPostProcess() override; private: - TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */ - TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */ - std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */ - int m_inputImgRows; /* Number of rows for model input. */ - int m_inputImgCols; /* Number of cols for model input. */ - float m_threshold; /* Post-processing threshold. */ - float m_nms; /* NMS threshold. */ - int m_numClasses; /* Number of classes. */ - int m_topN; /* TopN. */ - object_detection::Network m_net; /* YOLO network object. */ + TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */ + TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */ + std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */ + const object_detection::PostProcessParams& m_postProcessParams; /* Post processing param struct. */ + object_detection::Network m_net; /* YOLO network object. */ /** * @brief Insert the given Detection in the list. |