diff options
author | Richard Burton <richard.burton@arm.com> | 2022-05-17 12:52:50 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-05-20 11:08:32 +0100 |
commit | 6f6df0934f991b64fef494b86643b3f5074fca0e (patch) | |
tree | 3f04c50c7ee3bcaea4fa4e9c64c81b27d7cfc4fa /source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp | |
parent | 8c61c0a3cb8d6b534d1e423211e06b89f45bf223 (diff) | |
download | ml-embedded-evaluation-kit-6f6df0934f991b64fef494b86643b3f5074fca0e.tar.gz |
Remove dependency on extern defined constants from OD use case
OD API now takes in these paramaters as part of the constructor
Change-Id: I4cce25e364b2a99847b4540440db059997f6a81b
Diffstat (limited to 'source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp')
-rw-r--r-- | source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp | 47 |
1 files changed, 22 insertions, 25 deletions
diff --git a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp index 6a53688..b66edbf 100644 --- a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp +++ b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp @@ -28,6 +28,18 @@ namespace arm { namespace app { namespace object_detection { + struct PostProcessParams { + int inputImgRows{}; + int inputImgCols{}; + int originalImageSize{}; + const float* anchor1; + const float* anchor2; + float threshold = 0.5f; + float nms = 0.45f; + int numClasses = 1; + int topN = 0; + }; + struct Branch { int resolution; int numBox; @@ -57,25 +69,15 @@ namespace object_detection { public: /** * @brief Constructor. - * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0. - * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1. - * @param[out] results Vector of detected results. - * @param[in] inputImgRows Number of rows in the input image. - * @param[in] inputImgCols Number of columns in the input image. - * @param[in] threshold Post-processing threshold. - * @param[in] nms Non-maximum Suppression threshold. - * @param[in] numClasses Number of classes. - * @param[in] topN Top N for each class. + * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0. + * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1. + * @param[out] results Vector of detected results. + * @param[in] postProcessParams Struct of various parameters used in post-processing. **/ explicit DetectorPostProcess(TfLiteTensor* outputTensor0, TfLiteTensor* outputTensor1, std::vector<object_detection::DetectionResult>& results, - int inputImgRows, - int inputImgCols, - float threshold = 0.5f, - float nms = 0.45f, - int numClasses = 1, - int topN = 0); + const object_detection::PostProcessParams& postProcessParams); /** * @brief Should perform YOLO post-processing of the result of inference then @@ -85,16 +87,11 @@ namespace object_detection { bool DoPostProcess() override; private: - TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */ - TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */ - std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */ - int m_inputImgRows; /* Number of rows for model input. */ - int m_inputImgCols; /* Number of cols for model input. */ - float m_threshold; /* Post-processing threshold. */ - float m_nms; /* NMS threshold. */ - int m_numClasses; /* Number of classes. */ - int m_topN; /* TopN. */ - object_detection::Network m_net; /* YOLO network object. */ + TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */ + TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */ + std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */ + const object_detection::PostProcessParams& m_postProcessParams; /* Post processing param struct. */ + object_detection::Network m_net; /* YOLO network object. */ /** * @brief Insert the given Detection in the list. |