aboutsummaryrefslogtreecommitdiff
path: root/samples/ObjectDetection/src/YoloResultDecoder.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'samples/ObjectDetection/src/YoloResultDecoder.cpp')
-rw-r--r--samples/ObjectDetection/src/YoloResultDecoder.cpp100
1 files changed, 100 insertions, 0 deletions
diff --git a/samples/ObjectDetection/src/YoloResultDecoder.cpp b/samples/ObjectDetection/src/YoloResultDecoder.cpp
new file mode 100644
index 0000000000..ffbf7cb68d
--- /dev/null
+++ b/samples/ObjectDetection/src/YoloResultDecoder.cpp
@@ -0,0 +1,100 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "YoloResultDecoder.hpp"
+
+#include "NonMaxSuppression.hpp"
+
+#include <cassert>
+#include <stdexcept>
+
+namespace od
+{
+
+DetectedObjects YoloResultDecoder::Decode(const InferenceResults& networkResults,
+ const Size& outputFrameSize,
+ const Size& resizedFrameSize,
+ const std::vector<std::string>& labels)
+{
+
+ // Yolo v3 network outputs 1 tensor
+ if (networkResults.size() != 1)
+ {
+ throw std::runtime_error("Number of outputs from Yolo model doesn't equal 1");
+ }
+ auto element_step = m_boxElements + m_confidenceElements + m_numClasses;
+
+ float longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
+ float longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
+ const float resizeFactor = longEdgeOutput/longEdgeInput;
+
+ DetectedObjects detectedObjects;
+ DetectedObjects resultsAfterNMS;
+
+ for (const InferenceResult& result : networkResults)
+ {
+ for (unsigned int i = 0; i < m_numBoxes; ++i)
+ {
+ const float* cur_box = &result[i * element_step];
+ // Objectness score
+ if (cur_box[4] > m_objectThreshold)
+ {
+ for (unsigned int classIndex = 0; classIndex < m_numClasses; ++classIndex)
+ {
+ const float class_prob = cur_box[4] * cur_box[5 + classIndex];
+
+ // class confidence
+
+ if (class_prob > m_ClsThreshold)
+ {
+ DetectedObject detectedObject;
+
+ detectedObject.SetScore(class_prob);
+
+ float topLeftX = cur_box[0] * resizeFactor;
+ float topLeftY = cur_box[1] * resizeFactor;
+ float botRightX = cur_box[2] * resizeFactor;
+ float botRightY = cur_box[3] * resizeFactor;
+
+ assert(botRightX > topLeftX);
+ assert(botRightY > topLeftY);
+
+ detectedObject.SetBoundingBox({static_cast<int>(topLeftX),
+ static_cast<int>(topLeftY),
+ static_cast<unsigned int>(botRightX-topLeftX),
+ static_cast<unsigned int>(botRightY-topLeftY)});
+ if(labels.size() > classIndex)
+ {
+ detectedObject.SetLabel(labels.at(classIndex));
+ }
+ else
+ {
+ detectedObject.SetLabel(std::to_string(classIndex));
+ }
+ detectedObject.SetId(classIndex);
+ detectedObjects.emplace_back(detectedObject);
+ }
+ }
+ }
+ }
+
+ std::vector<int> keepIndiciesAfterNMS = od::NonMaxSuppression(detectedObjects, m_NmsThreshold);
+
+ for (const int ind: keepIndiciesAfterNMS)
+ {
+ resultsAfterNMS.emplace_back(detectedObjects[ind]);
+ }
+ }
+
+ return resultsAfterNMS;
+}
+
+YoloResultDecoder::YoloResultDecoder(float NMSThreshold, float ClsThreshold, float ObjectThreshold)
+ : m_NmsThreshold(NMSThreshold), m_ClsThreshold(ClsThreshold), m_objectThreshold(ObjectThreshold) {}
+
+}// namespace od
+
+
+