diff options
Diffstat (limited to 'source/application/api/use_case/object_detection/src')
-rw-r--r-- | source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc | 42 |
1 files changed, 16 insertions, 26 deletions
diff --git a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc index 7610c4f..f555fbb 100644 --- a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc +++ b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc @@ -26,31 +26,21 @@ namespace app { TfLiteTensor* modelOutput0, TfLiteTensor* modelOutput1, std::vector<object_detection::DetectionResult>& results, - int inputImgRows, - int inputImgCols, - const float threshold, - const float nms, - int numClasses, - int topN) + const object_detection::PostProcessParams& postProcessParams) : m_outputTensor0{modelOutput0}, m_outputTensor1{modelOutput1}, m_results{results}, - m_inputImgRows{inputImgRows}, - m_inputImgCols{inputImgCols}, - m_threshold(threshold), - m_nms(nms), - m_numClasses(numClasses), - m_topN(topN) + m_postProcessParams{postProcessParams} { /* Init PostProcessing */ this->m_net = object_detection::Network{ - .inputWidth = inputImgCols, - .inputHeight = inputImgRows, - .numClasses = numClasses, + .inputWidth = postProcessParams.inputImgCols, + .inputHeight = postProcessParams.inputImgRows, + .numClasses = postProcessParams.numClasses, .branches = - {object_detection::Branch{.resolution = inputImgCols / 32, + {object_detection::Branch{.resolution = postProcessParams.inputImgCols / 32, .numBox = 3, - .anchor = arm::app::object_detection::anchor1, + .anchor = postProcessParams.anchor1, .modelOutput = this->m_outputTensor0->data.int8, .scale = (static_cast<TfLiteAffineQuantization*>( this->m_outputTensor0->quantization.params)) @@ -59,9 +49,9 @@ namespace app { this->m_outputTensor0->quantization.params)) ->zero_point->data[0], .size = this->m_outputTensor0->bytes}, - object_detection::Branch{.resolution = inputImgCols / 16, + object_detection::Branch{.resolution = postProcessParams.inputImgCols / 16, .numBox = 3, - .anchor = arm::app::object_detection::anchor2, + .anchor = postProcessParams.anchor2, .modelOutput = this->m_outputTensor1->data.int8, .scale = (static_cast<TfLiteAffineQuantization*>( this->m_outputTensor1->quantization.params)) @@ -70,21 +60,21 @@ namespace app { this->m_outputTensor1->quantization.params)) ->zero_point->data[0], .size = this->m_outputTensor1->bytes}}, - .topN = m_topN}; + .topN = postProcessParams.topN}; /* End init */ } bool DetectorPostProcess::DoPostProcess() { /* Start postprocessing */ - int originalImageWidth = arm::app::object_detection::originalImageSize; - int originalImageHeight = arm::app::object_detection::originalImageSize; + int originalImageWidth = m_postProcessParams.originalImageSize; + int originalImageHeight = m_postProcessParams.originalImageSize; std::forward_list<image::Detection> detections; - GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections); + GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_postProcessParams.threshold, detections); /* Do nms */ - CalculateNMS(detections, this->m_net.numClasses, m_nms); + CalculateNMS(detections, this->m_net.numClasses, this->m_postProcessParams.nms); for (auto& it: detections) { float xMin = it.bbox.x - it.bbox.w / 2.0f; @@ -219,10 +209,10 @@ void DetectorPostProcess::GetNetworkBoxes( num += 1; } else if (num == net.topN) { detections.sort(det_objectness_comparator); - InsertTopNDetections(detections,det); + InsertTopNDetections(detections, det); num += 1; } else { - InsertTopNDetections(detections,det); + InsertTopNDetections(detections, det); } } } |