// // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "SSDResultDecoder.hpp" #include #include #include #include namespace od { DetectedObjects SSDResultDecoder::Decode(const common::InferenceResults& networkResults, const common::Size& outputFrameSize, const common::Size& resizedFrameSize, const std::vector& labels) { // SSD network outputs 4 tensors: bounding boxes, labels, probabilities, number of detections. if (networkResults.size() != 4) { throw std::runtime_error("Number of outputs from SSD model doesn't equal 4"); } DetectedObjects detectedObjects; const int numDetections = static_cast(std::lround(networkResults[3][0])); double longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height); double longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height); const double resizeFactor = longEdgeOutput/longEdgeInput; for (int i=0; i m_objectThreshold) { DetectedObject detectedObject; detectedObject.SetScore(networkResults[2][i]); auto classId = std::lround(networkResults[1][i]); if (classId < labels.size()) { detectedObject.SetLabel(labels[classId]); } else { detectedObject.SetLabel(std::to_string(classId)); } detectedObject.SetId(classId); // Convert SSD bbox outputs (ratios of image size) to pixel values. double topLeftY = networkResults[0][i*4 + 0] * resizedFrameSize.m_Height; double topLeftX = networkResults[0][i*4 + 1] * resizedFrameSize.m_Width; double botRightY = networkResults[0][i*4 + 2] * resizedFrameSize.m_Height; double botRightX = networkResults[0][i*4 + 3] * resizedFrameSize.m_Width; // Scale the coordinates to output frame size. topLeftY *= resizeFactor; topLeftX *= resizeFactor; botRightY *= resizeFactor; botRightX *= resizeFactor; assert(botRightX > topLeftX); assert(botRightY > topLeftY); // Internal BoundingBox stores box top left x,y and width, height. detectedObject.SetBoundingBox({static_cast(std::round(topLeftX)), static_cast(std::round(topLeftY)), static_cast(botRightX - topLeftX), static_cast(botRightY - topLeftY)}); detectedObjects.emplace_back(detectedObject); } } return detectedObjects; } SSDResultDecoder::SSDResultDecoder(float ObjectThreshold) : m_objectThreshold(ObjectThreshold) {} }// namespace od