From 6f6df0934f991b64fef494b86643b3f5074fca0e Mon Sep 17 00:00:00 2001 From: Richard Burton Date: Tue, 17 May 2022 12:52:50 +0100 Subject: Remove dependency on extern defined constants from OD use case OD API now takes in these parameters as part of the constructor Change-Id: I4cce25e364b2a99847b4540440db059997f6a81b --- .../include/DetectorPostProcessing.hpp | 47 ++++++++++------------ .../object_detection/src/DetectorPostProcessing.cc | 42 ++++++++----------- .../object_detection/src/UseCaseHandler.cc | 9 +++-- .../object_detection/InferenceTestYoloFastest.cc | 6 ++- 4 files changed, 48 insertions(+), 56 deletions(-) diff --git a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp index 6a53688..b66edbf 100644 --- a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp +++ b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp @@ -28,6 +28,18 @@ namespace arm { namespace app { namespace object_detection { + struct PostProcessParams { + int inputImgRows{}; + int inputImgCols{}; + int originalImageSize{}; + const float* anchor1; + const float* anchor2; + float threshold = 0.5f; + float nms = 0.45f; + int numClasses = 1; + int topN = 0; + }; + struct Branch { int resolution; int numBox; @@ -57,25 +69,15 @@ namespace object_detection { public: /** * @brief Constructor. - * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0. - * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1. - * @param[out] results Vector of detected results. - * @param[in] inputImgRows Number of rows in the input image. - * @param[in] inputImgCols Number of columns in the input image. - * @param[in] threshold Post-processing threshold. - * @param[in] nms Non-maximum Suppression threshold. - * @param[in] numClasses Number of classes.
- * @param[in] topN Top N for each class. + * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0. + * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1. + * @param[out] results Vector of detected results. + * @param[in] postProcessParams Struct of various parameters used in post-processing. **/ explicit DetectorPostProcess(TfLiteTensor* outputTensor0, TfLiteTensor* outputTensor1, std::vector& results, - int inputImgRows, - int inputImgCols, - float threshold = 0.5f, - float nms = 0.45f, - int numClasses = 1, - int topN = 0); + const object_detection::PostProcessParams& postProcessParams); /** * @brief Should perform YOLO post-processing of the result of inference then @@ -85,16 +87,11 @@ namespace object_detection { bool DoPostProcess() override; private: - TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */ - TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */ - std::vector& m_results; /* Single inference results. */ - int m_inputImgRows; /* Number of rows for model input. */ - int m_inputImgCols; /* Number of cols for model input. */ - float m_threshold; /* Post-processing threshold. */ - float m_nms; /* NMS threshold. */ - int m_numClasses; /* Number of classes. */ - int m_topN; /* TopN. */ - object_detection::Network m_net; /* YOLO network object. */ + TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */ + TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */ + std::vector& m_results; /* Single inference results. */ + const object_detection::PostProcessParams& m_postProcessParams; /* Post processing param struct. */ + object_detection::Network m_net; /* YOLO network object. */ /** * @brief Insert the given Detection in the list. 
diff --git a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc index 7610c4f..f555fbb 100644 --- a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc +++ b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc @@ -26,31 +26,21 @@ namespace app { TfLiteTensor* modelOutput0, TfLiteTensor* modelOutput1, std::vector& results, - int inputImgRows, - int inputImgCols, - const float threshold, - const float nms, - int numClasses, - int topN) + const object_detection::PostProcessParams& postProcessParams) : m_outputTensor0{modelOutput0}, m_outputTensor1{modelOutput1}, m_results{results}, - m_inputImgRows{inputImgRows}, - m_inputImgCols{inputImgCols}, - m_threshold(threshold), - m_nms(nms), - m_numClasses(numClasses), - m_topN(topN) + m_postProcessParams{postProcessParams} { /* Init PostProcessing */ this->m_net = object_detection::Network{ - .inputWidth = inputImgCols, - .inputHeight = inputImgRows, - .numClasses = numClasses, + .inputWidth = postProcessParams.inputImgCols, + .inputHeight = postProcessParams.inputImgRows, + .numClasses = postProcessParams.numClasses, .branches = - {object_detection::Branch{.resolution = inputImgCols / 32, + {object_detection::Branch{.resolution = postProcessParams.inputImgCols / 32, .numBox = 3, - .anchor = arm::app::object_detection::anchor1, + .anchor = postProcessParams.anchor1, .modelOutput = this->m_outputTensor0->data.int8, .scale = (static_cast( this->m_outputTensor0->quantization.params)) @@ -59,9 +49,9 @@ namespace app { this->m_outputTensor0->quantization.params)) ->zero_point->data[0], .size = this->m_outputTensor0->bytes}, - object_detection::Branch{.resolution = inputImgCols / 16, + object_detection::Branch{.resolution = postProcessParams.inputImgCols / 16, .numBox = 3, - .anchor = arm::app::object_detection::anchor2, + .anchor = postProcessParams.anchor2, 
.modelOutput = this->m_outputTensor1->data.int8, .scale = (static_cast( this->m_outputTensor1->quantization.params)) @@ -70,21 +60,21 @@ namespace app { this->m_outputTensor1->quantization.params)) ->zero_point->data[0], .size = this->m_outputTensor1->bytes}}, - .topN = m_topN}; + .topN = postProcessParams.topN}; /* End init */ } bool DetectorPostProcess::DoPostProcess() { /* Start postprocessing */ - int originalImageWidth = arm::app::object_detection::originalImageSize; - int originalImageHeight = arm::app::object_detection::originalImageSize; + int originalImageWidth = m_postProcessParams.originalImageSize; + int originalImageHeight = m_postProcessParams.originalImageSize; std::forward_list detections; - GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections); + GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_postProcessParams.threshold, detections); /* Do nms */ - CalculateNMS(detections, this->m_net.numClasses, m_nms); + CalculateNMS(detections, this->m_net.numClasses, this->m_postProcessParams.nms); for (auto& it: detections) { float xMin = it.bbox.x - it.bbox.w / 2.0f; @@ -219,10 +209,10 @@ void DetectorPostProcess::GetNetworkBoxes( num += 1; } else if (num == net.topN) { detections.sort(det_objectness_comparator); - InsertTopNDetections(detections,det); + InsertTopNDetections(detections, det); num += 1; } else { - InsertTopNDetections(detections,det); + InsertTopNDetections(detections, det); } } } diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc index e9bcd4a..a7acb46 100644 --- a/source/use_case/object_detection/src/UseCaseHandler.cc +++ b/source/use_case/object_detection/src/UseCaseHandler.cc @@ -27,9 +27,6 @@ namespace arm { namespace app { - namespace object_detection { - extern const int channelsImageDisplayed; - } /* namespace object_detection */ /** * @brief Presents inference results along using the data 
presentation @@ -102,8 +99,12 @@ namespace app { DetectorPreProcess preProcess = DetectorPreProcess(inputTensor, true, model.IsDataSigned()); std::vector results; + const object_detection::PostProcessParams postProcessParams { + inputImgRows, inputImgCols, object_detection::originalImageSize, + object_detection::anchor1, object_detection::anchor2 + }; DetectorPostProcess postProcess = DetectorPostProcess(outputTensor0, outputTensor1, - results, inputImgRows, inputImgCols); + results, postProcessParams); do { /* Ensure there are no results leftover from previous inference when running all. */ results.clear(); diff --git a/tests/use_case/object_detection/InferenceTestYoloFastest.cc b/tests/use_case/object_detection/InferenceTestYoloFastest.cc index eb92904..d328684 100644 --- a/tests/use_case/object_detection/InferenceTestYoloFastest.cc +++ b/tests/use_case/object_detection/InferenceTestYoloFastest.cc @@ -104,7 +104,11 @@ void TestInferenceDetectionResults(int imageIdx, arm::app::Model& model, T toler REQUIRE(tflite::GetTensorData(output_arr[i])); } - arm::app::DetectorPostProcess postp{output_arr[0], output_arr[1], results, nRows, nCols}; + const arm::app::object_detection::PostProcessParams postProcessParams { + nRows, nCols, arm::app::object_detection::originalImageSize, + arm::app::object_detection::anchor1, arm::app::object_detection::anchor2 + }; + arm::app::DetectorPostProcess postp{output_arr[0], output_arr[1], results, postProcessParams}; postp.DoPostProcess(); std::vector> expected_results; -- cgit v1.2.1