summaryrefslogtreecommitdiff
path: root/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc')
-rw-r--r--source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc42
1 files changed, 16 insertions, 26 deletions
diff --git a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc
index 7610c4f..f555fbb 100644
--- a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc
+++ b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc
@@ -26,31 +26,21 @@ namespace app {
TfLiteTensor* modelOutput0,
TfLiteTensor* modelOutput1,
std::vector<object_detection::DetectionResult>& results,
- int inputImgRows,
- int inputImgCols,
- const float threshold,
- const float nms,
- int numClasses,
- int topN)
+ const object_detection::PostProcessParams& postProcessParams)
: m_outputTensor0{modelOutput0},
m_outputTensor1{modelOutput1},
m_results{results},
- m_inputImgRows{inputImgRows},
- m_inputImgCols{inputImgCols},
- m_threshold(threshold),
- m_nms(nms),
- m_numClasses(numClasses),
- m_topN(topN)
+ m_postProcessParams{postProcessParams}
{
/* Init PostProcessing */
this->m_net = object_detection::Network{
- .inputWidth = inputImgCols,
- .inputHeight = inputImgRows,
- .numClasses = numClasses,
+ .inputWidth = postProcessParams.inputImgCols,
+ .inputHeight = postProcessParams.inputImgRows,
+ .numClasses = postProcessParams.numClasses,
.branches =
- {object_detection::Branch{.resolution = inputImgCols / 32,
+ {object_detection::Branch{.resolution = postProcessParams.inputImgCols / 32,
.numBox = 3,
- .anchor = arm::app::object_detection::anchor1,
+ .anchor = postProcessParams.anchor1,
.modelOutput = this->m_outputTensor0->data.int8,
.scale = (static_cast<TfLiteAffineQuantization*>(
this->m_outputTensor0->quantization.params))
@@ -59,9 +49,9 @@ namespace app {
this->m_outputTensor0->quantization.params))
->zero_point->data[0],
.size = this->m_outputTensor0->bytes},
- object_detection::Branch{.resolution = inputImgCols / 16,
+ object_detection::Branch{.resolution = postProcessParams.inputImgCols / 16,
.numBox = 3,
- .anchor = arm::app::object_detection::anchor2,
+ .anchor = postProcessParams.anchor2,
.modelOutput = this->m_outputTensor1->data.int8,
.scale = (static_cast<TfLiteAffineQuantization*>(
this->m_outputTensor1->quantization.params))
@@ -70,21 +60,21 @@ namespace app {
this->m_outputTensor1->quantization.params))
->zero_point->data[0],
.size = this->m_outputTensor1->bytes}},
- .topN = m_topN};
+ .topN = postProcessParams.topN};
/* End init */
}
bool DetectorPostProcess::DoPostProcess()
{
/* Start postprocessing */
- int originalImageWidth = arm::app::object_detection::originalImageSize;
- int originalImageHeight = arm::app::object_detection::originalImageSize;
+ int originalImageWidth = m_postProcessParams.originalImageSize;
+ int originalImageHeight = m_postProcessParams.originalImageSize;
std::forward_list<image::Detection> detections;
- GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections);
+ GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_postProcessParams.threshold, detections);
/* Do nms */
- CalculateNMS(detections, this->m_net.numClasses, m_nms);
+ CalculateNMS(detections, this->m_net.numClasses, this->m_postProcessParams.nms);
for (auto& it: detections) {
float xMin = it.bbox.x - it.bbox.w / 2.0f;
@@ -219,10 +209,10 @@ void DetectorPostProcess::GetNetworkBoxes(
num += 1;
} else if (num == net.topN) {
detections.sort(det_objectness_comparator);
- InsertTopNDetections(detections,det);
+ InsertTopNDetections(detections, det);
num += 1;
} else {
- InsertTopNDetections(detections,det);
+ InsertTopNDetections(detections, det);
}
}
}