// // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "YoloResultDecoder.hpp" #include "NonMaxSuppression.hpp" #include #include namespace od { DetectedObjects YoloResultDecoder::Decode(const common::InferenceResults& networkResults, const common::Size& outputFrameSize, const common::Size& resizedFrameSize, const std::vector& labels) { // Yolo v3 network outputs 1 tensor if (networkResults.size() != 1) { throw std::runtime_error("Number of outputs from Yolo model doesn't equal 1"); } auto element_step = m_boxElements + m_confidenceElements + m_numClasses; float longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height); float longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height); const float resizeFactor = longEdgeOutput/longEdgeInput; DetectedObjects detectedObjects; DetectedObjects resultsAfterNMS; for (const common::InferenceResult& result : networkResults) { for (unsigned int i = 0; i < m_numBoxes; ++i) { const float* cur_box = &result[i * element_step]; // Objectness score if (cur_box[4] > m_objectThreshold) { for (unsigned int classIndex = 0; classIndex < m_numClasses; ++classIndex) { const float class_prob = cur_box[4] * cur_box[5 + classIndex]; // class confidence if (class_prob > m_ClsThreshold) { DetectedObject detectedObject; detectedObject.SetScore(class_prob); float topLeftX = cur_box[0] * resizeFactor; float topLeftY = cur_box[1] * resizeFactor; float botRightX = cur_box[2] * resizeFactor; float botRightY = cur_box[3] * resizeFactor; assert(botRightX > topLeftX); assert(botRightY > topLeftY); detectedObject.SetBoundingBox({static_cast(topLeftX), static_cast(topLeftY), static_cast(botRightX-topLeftX), static_cast(botRightY-topLeftY)}); if(labels.size() > classIndex) { detectedObject.SetLabel(labels.at(classIndex)); } else { detectedObject.SetLabel(std::to_string(classIndex)); } detectedObject.SetId(classIndex); detectedObjects.emplace_back(detectedObject); } } } } std::vector keepIndiciesAfterNMS = od::NonMaxSuppression(detectedObjects, m_NmsThreshold); for (const int ind: keepIndiciesAfterNMS) { resultsAfterNMS.emplace_back(detectedObjects[ind]); } } return resultsAfterNMS; } YoloResultDecoder::YoloResultDecoder(float NMSThreshold, float ClsThreshold, float ObjectThreshold) : m_NmsThreshold(NMSThreshold), m_ClsThreshold(ClsThreshold), m_objectThreshold(ObjectThreshold) {} }// namespace od