diff options
Diffstat (limited to 'source/use_case/object_detection/src/DetectorPostProcessing.cc')
-rw-r--r-- | source/use_case/object_detection/src/DetectorPostProcessing.cc | 123 |
1 files changed, 70 insertions, 53 deletions
diff --git a/source/use_case/object_detection/src/DetectorPostProcessing.cc b/source/use_case/object_detection/src/DetectorPostProcessing.cc index a890c9e..fb1606a 100644 --- a/source/use_case/object_detection/src/DetectorPostProcessing.cc +++ b/source/use_case/object_detection/src/DetectorPostProcessing.cc @@ -21,64 +21,73 @@ namespace arm { namespace app { -namespace object_detection { - -DetectorPostprocessing::DetectorPostprocessing( - const float threshold, - const float nms, - int numClasses, - int topN) - : m_threshold(threshold), - m_nms(nms), - m_numClasses(numClasses), - m_topN(topN) -{} - -void DetectorPostprocessing::RunPostProcessing( - uint32_t imgRows, - uint32_t imgCols, - TfLiteTensor* modelOutput0, - TfLiteTensor* modelOutput1, - std::vector<DetectionResult>& resultsOut) + + DetectorPostProcess::DetectorPostProcess( + TfLiteTensor* modelOutput0, + TfLiteTensor* modelOutput1, + std::vector<object_detection::DetectionResult>& results, + int inputImgRows, + int inputImgCols, + const float threshold, + const float nms, + int numClasses, + int topN) + : m_outputTensor0{modelOutput0}, + m_outputTensor1{modelOutput1}, + m_results{results}, + m_inputImgRows{inputImgRows}, + m_inputImgCols{inputImgCols}, + m_threshold(threshold), + m_nms(nms), + m_numClasses(numClasses), + m_topN(topN) { - /* init postprocessing */ - Network net { - .inputWidth = static_cast<int>(imgCols), - .inputHeight = static_cast<int>(imgRows), - .numClasses = m_numClasses, + /* Init PostProcessing */ + this->m_net = + object_detection::Network { + .inputWidth = inputImgCols, + .inputHeight = inputImgRows, + .numClasses = numClasses, .branches = { - Branch { - .resolution = static_cast<int>(imgCols/32), - .numBox = 3, - .anchor = anchor1, - .modelOutput = modelOutput0->data.int8, - .scale = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->scale->data[0], - .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->zero_point->data[0], - .size = modelOutput0->bytes + object_detection::Branch { + .resolution = inputImgCols/32, + .numBox = 3, + .anchor = anchor1, + .modelOutput = this->m_outputTensor0->data.int8, + .scale = (static_cast<TfLiteAffineQuantization*>( + this->m_outputTensor0->quantization.params))->scale->data[0], + .zeroPoint = (static_cast<TfLiteAffineQuantization*>( + this->m_outputTensor0->quantization.params))->zero_point->data[0], + .size = this->m_outputTensor0->bytes }, - Branch { - .resolution = static_cast<int>(imgCols/16), - .numBox = 3, - .anchor = anchor2, - .modelOutput = modelOutput1->data.int8, - .scale = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->scale->data[0], - .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->zero_point->data[0], - .size = modelOutput1->bytes + object_detection::Branch { + .resolution = inputImgCols/16, + .numBox = 3, + .anchor = anchor2, + .modelOutput = this->m_outputTensor1->data.int8, + .scale = (static_cast<TfLiteAffineQuantization*>( + this->m_outputTensor1->quantization.params))->scale->data[0], + .zeroPoint = (static_cast<TfLiteAffineQuantization*>( + this->m_outputTensor1->quantization.params))->zero_point->data[0], + .size = this->m_outputTensor1->bytes } }, .topN = m_topN }; /* End init */ +} +bool DetectorPostProcess::DoPostProcess() +{ /* Start postprocessing */ int originalImageWidth = originalImageSize; int originalImageHeight = originalImageSize; std::forward_list<image::Detection> detections; - GetNetworkBoxes(net, originalImageWidth, originalImageHeight, m_threshold, detections); + GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections); /* Do nms */ - CalculateNMS(detections, net.numClasses, m_nms); + CalculateNMS(detections, this->m_net.numClasses, m_nms); for (auto& it: detections) { float xMin = it.bbox.x - it.bbox.w / 2.0f; @@ -104,24 +113,24 @@ void DetectorPostprocessing::RunPostProcessing( float boxWidth = xMax - xMin; float boxHeight = yMax - yMin; - for (int j = 0; j < net.numClasses; ++j) { + for (int j = 0; j < this->m_net.numClasses; ++j) { if (it.prob[j] > 0) { - DetectionResult tmpResult = {}; + object_detection::DetectionResult tmpResult = {}; tmpResult.m_normalisedVal = it.prob[j]; tmpResult.m_x0 = boxX; tmpResult.m_y0 = boxY; tmpResult.m_w = boxWidth; tmpResult.m_h = boxHeight; - resultsOut.push_back(tmpResult); + this->m_results.push_back(tmpResult); } } } + return true; } - -void DetectorPostprocessing::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det) +void DetectorPostProcess::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det) { std::forward_list<image::Detection>::iterator it; std::forward_list<image::Detection>::iterator last_it; @@ -136,7 +145,12 @@ void DetectorPostprocessing::InsertTopNDetections(std::forward_list<image::Detec } } -void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int imageHeight, float threshold, std::forward_list<image::Detection>& detections) +void DetectorPostProcess::GetNetworkBoxes( + object_detection::Network& net, + int imageWidth, + int imageHeight, + float threshold, + std::forward_list<image::Detection>& detections) { int numClasses = net.numClasses; int num = 0; @@ -169,10 +183,14 @@ void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int i int bbox_h_offset = bbox_x_offset + 3; int bbox_scores_offset = bbox_x_offset + 5; - det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale; - det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale; - det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale; - det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale; + det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset]) + - net.branches[i].zeroPoint) * net.branches[i].scale; + det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset]) + - net.branches[i].zeroPoint) * net.branches[i].scale; + det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset]) + - net.branches[i].zeroPoint) * net.branches[i].scale; + det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset]) + - net.branches[i].zeroPoint) * net.branches[i].scale; float bbox_x, bbox_y; @@ -218,6 +236,5 @@ void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int i num -=1; } -} /* namespace object_detection */ } /* namespace app */ } /* namespace arm */ |