summaryrefslogtreecommitdiff
path: root/source/use_case/object_detection/include/DetectorPostProcessing.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/object_detection/include/DetectorPostProcessing.hpp')
-rw-r--r--source/use_case/object_detection/include/DetectorPostProcessing.hpp249
1 files changed, 194 insertions, 55 deletions
diff --git a/source/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
index 9a8549c..3e9c819 100644
--- a/source/use_case/object_detection/include/DetectorPostProcessing.hpp
+++ b/source/use_case/object_detection/include/DetectorPostProcessing.hpp
@@ -1,55 +1,194 @@
-/*
- * Copyright (c) 2022 Arm Limited. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DETECTOR_POST_PROCESSING_HPP
-#define DETECTOR_POST_PROCESSING_HPP
-
-#include "UseCaseCommonUtils.hpp"
-#include "DetectionResult.hpp"
-
-namespace arm {
-namespace app {
-
-#if DISPLAY_RGB_IMAGE
-#define FORMAT_MULTIPLY_FACTOR 3
-#else
-#define FORMAT_MULTIPLY_FACTOR 1
-#endif /* DISPLAY_RGB_IMAGE */
-
- /**
- * @brief Post processing part of Yolo object detection CNN
- * @param[in] img_in Pointer to the input image,detection bounding boxes drown on it.
- * @param[in] model_output Output tesnsors after CNN invoked
- * @param[out] results_out Vector of detected results.
- * @return void
- **/
-void RunPostProcessing(uint8_t *img_in,TfLiteTensor* model_output[2],std::vector<arm::app::DetectionResult> & results_out);
-
-
- /**
- * @brief Converts RGB image to grayscale
- * @param[in] rgb Pointer to RGB input image
- * @param[out] gray Pointer to RGB out image
- * @param[in] im_w Input image width
- * @param[in] im_h Input image height
- * @return void
- **/
-void RgbToGrayscale(const uint8_t *rgb,uint8_t *gray, int im_w,int im_h);
-
-} /* namespace app */
-} /* namespace arm */
-
-#endif /* DETECTOR_POST_PROCESSING_HPP */
+/*
+ * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DETECTOR_POST_PROCESSING_HPP
+#define DETECTOR_POST_PROCESSING_HPP
+
+#include "UseCaseCommonUtils.hpp"
+#include "DetectionResult.hpp"
+#include "YoloFastestModel.hpp"
+
+#include <forward_list>
+
+namespace arm {
+namespace app {
+namespace object_detection {
+
+ struct Branch { /* Data-only description of one model output branch; no member functions. */
+ int resolution; /* NOTE(review): presumably the grid resolution of this branch - confirm. */
+ int numBox; /* NOTE(review): presumably the number of boxes predicted per grid cell - confirm. */
+ const float* anchor; /* Pointer to read-only float values; presumably the anchor sizes - confirm. */
+ int8_t* modelOutput; /* Pointer to int8 data; presumably the raw output tensor data - confirm. */
+ float scale; /* NOTE(review): presumably the quantisation scale for modelOutput - confirm. */
+ int zeroPoint; /* NOTE(review): presumably the quantisation zero point for modelOutput - confirm. */
+ size_t size; /* NOTE(review): presumably the size of modelOutput; bytes vs. elements not visible here - confirm. */
+ };
+
+ struct Network { /* Data-only description of a network: input size, class count and its Branch entries. */
+ int inputWidth; /* NOTE(review): presumably the model input width in pixels - confirm. */
+ int inputHeight; /* NOTE(review): presumably the model input height in pixels - confirm. */
+ int numClasses; /* Number of classes. */
+ std::vector<Branch> branches; /* Branch entries belonging to this network (held by value). */
+ int topN; /* NOTE(review): presumably mirrors the post-processor's TopN setting - confirm. */
+ };
+
+
+ struct Box { /* Box described by four floats: a position (x, y) and a size (w, h). */
+ float x; /* NOTE(review): presumably the box centre on the X axis (Calculate1DOverlap takes centre/width pairs) - confirm. */
+ float y; /* NOTE(review): presumably the box centre on the Y axis - confirm. */
+ float w; /* Box width. */
+ float h; /* Box height. */
+ };
+
+ struct Detection { /* One candidate detection: a box plus its scores. */
+ Box bbox; /* Bounding box of the detection. */
+ std::vector<float> prob; /* NOTE(review): presumably one probability per class - confirm. */
+ float objectness; /* NOTE(review): presumably the objectness score of the box - confirm. */
+ };
+
+ /**
+ * @brief Helper class to manage tensor post-processing for "object_detection"
+ * output.
+ */
+ class DetectorPostprocessing {
+ public:
+ /**
+ * @brief Constructor.
+ * @param[in] threshold Post-processing threshold (defaults to 0.5).
+ * @param[in] nms Non-maximum Suppression threshold (defaults to 0.45).
+ * @param[in] numClasses Number of classes (defaults to 1).
+ * @param[in] topN Top N for each class (defaults to 0).
+ **/
+ DetectorPostprocessing(float threshold = 0.5f,
+ float nms = 0.45f,
+ int numClasses = 1,
+ int topN = 0);
+
+ /**
+ * @brief Post processing part of Yolo object detection CNN.
+ * @param[in,out] imgIn Pointer to the input image; detection bounding boxes are drawn on it.
+ * @param[in] imgRows Number of rows in the input image.
+ * @param[in] imgCols Number of columns in the input image.
+ * @param[in] modelOutput0,modelOutput1 Output tensors after the CNN has been invoked.
+ * @param[out] resultsOut Vector of detected results.
+ **/
+ void RunPostProcessing(uint8_t* imgIn,
+ uint32_t imgRows,
+ uint32_t imgCols,
+ TfLiteTensor* modelOutput0,
+ TfLiteTensor* modelOutput1,
+ std::vector<DetectionResult>& resultsOut);
+
+ private:
+ float m_threshold; /* Post-processing threshold */
+ float m_nms; /* NMS threshold */
+ int m_numClasses; /* Number of classes */
+ int m_topN; /* TopN */
+
+ /**
+ * @brief Calculate the Sigmoid function of the given value.
+ * @param[in] x Value.
+ * @return Sigmoid value of the input.
+ **/
+ float Sigmoid(float x);
+
+ /**
+ * @brief Insert the given Detection in the list.
+ * @param[in,out] detections List of detections that receives the new entry.
+ * @param[in] det Detection to be inserted.
+ **/
+ void InsertTopNDetections(std::forward_list<Detection>& detections, Detection& det);
+
+ /**
+ * @brief Given a Network calculate the detection boxes.
+ * @param[in] net Network.
+ * @param[in] imageWidth Original image width.
+ * @param[in] imageHeight Original image height.
+ * @param[in] threshold Detections threshold.
+ * @param[out] detections Detection boxes.
+ **/
+ void GetNetworkBoxes(Network& net,
+ int imageWidth,
+ int imageHeight,
+ float threshold,
+ std::forward_list<Detection>& detections);
+
+ /**
+ * @brief Calculate the 1D overlap.
+ * @param[in] x1Center First center point.
+ * @param[in] width1 First width.
+ * @param[in] x2Center Second center point.
+ * @param[in] width2 Second width.
+ * @return The overlap between the two lines.
+ **/
+ float Calculate1DOverlap(float x1Center, float width1, float x2Center, float width2);
+
+ /**
+ * @brief Calculate the intersection between the two given boxes.
+ * @param[in] box1 First box.
+ * @param[in] box2 Second box.
+ * @return The intersection value.
+ **/
+ float CalculateBoxIntersect(Box& box1, Box& box2);
+
+ /**
+ * @brief Calculate the union between the two given boxes.
+ * @param[in] box1 First box.
+ * @param[in] box2 Second box.
+ * @return The two given boxes union value.
+ **/
+ float CalculateBoxUnion(Box& box1, Box& box2);
+ /**
+ * @brief Calculate the intersection over union between the two given boxes.
+ * @param[in] box1 First box.
+ * @param[in] box2 Second box.
+ * @return The intersection over union value.
+ **/
+ float CalculateBoxIOU(Box& box1, Box& box2);
+
+ /**
+ * @brief Calculate the Non-Maxima suppression on the given detection boxes.
+ * @param[in] detections Detection boxes.
+ * @param[in] classes Number of classes.
+ * @param[in] iouThreshold Intersection over union threshold.
+ * @note Declared void: no success/failure status is returned to the caller.
+ **/
+ void CalculateNMS(std::forward_list<Detection>& detections, int classes, float iouThreshold);
+
+ /**
+ * @brief Draw on the given image a bounding box starting at (boxX, boxY).
+ * @param[in,out] imgIn Image.
+ * @param[in] imWidth Image width.
+ * @param[in] imHeight Image height.
+ * @param[in] boxX Axis X starting point.
+ * @param[in] boxY Axis Y starting point.
+ * @param[in] boxWidth Box width.
+ * @param[in] boxHeight Box height.
+ **/
+ void DrawBoxOnImage(uint8_t* imgIn,
+ int imWidth,
+ int imHeight,
+ int boxX,
+ int boxY,
+ int boxWidth,
+ int boxHeight);
+ };
+
+} /* namespace object_detection */
+} /* namespace app */
+} /* namespace arm */
+
+#endif /* DETECTOR_POST_PROCESSING_HPP */