diff options
Diffstat (limited to 'samples/ObjectDetection/src/NetworkPipeline.cpp')
-rw-r--r-- | samples/ObjectDetection/src/NetworkPipeline.cpp | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/samples/ObjectDetection/src/NetworkPipeline.cpp b/samples/ObjectDetection/src/NetworkPipeline.cpp new file mode 100644 index 0000000000..7f05882fc4 --- /dev/null +++ b/samples/ObjectDetection/src/NetworkPipeline.cpp @@ -0,0 +1,102 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "NetworkPipeline.hpp" +#include "ImageUtils.hpp" + +namespace od +{ + +ObjDetectionPipeline::ObjDetectionPipeline(std::unique_ptr<ArmnnNetworkExecutor> executor, + std::unique_ptr<IDetectionResultDecoder> decoder) : + m_executor(std::move(executor)), + m_decoder(std::move(decoder)){} + +void od::ObjDetectionPipeline::Inference(const cv::Mat& processed, InferenceResults& result) +{ + m_executor->Run(processed.data, processed.total() * processed.elemSize(), result); +} + +void ObjDetectionPipeline::PostProcessing(InferenceResults& inferenceResult, + const std::function<void (DetectedObjects)>& callback) +{ + DetectedObjects detections = m_decoder->Decode(inferenceResult, m_inputImageSize, + m_executor->GetImageAspectRatio(), {}); + if (callback) + { + callback(detections); + } +} + +void ObjDetectionPipeline::PreProcessing(const cv::Mat& frame, cv::Mat& processed) +{ + m_inputImageSize.m_Height = frame.rows; + m_inputImageSize.m_Width = frame.cols; + ResizeWithPad(frame, processed, m_processedFrame, m_executor->GetImageAspectRatio()); +} + +MobileNetSSDv1::MobileNetSSDv1(std::unique_ptr<ArmnnNetworkExecutor> executor, + float objectThreshold) : + ObjDetectionPipeline(std::move(executor), + std::make_unique<SSDResultDecoder>(objectThreshold)) +{} + +void MobileNetSSDv1::PreProcessing(const cv::Mat& frame, cv::Mat& processed) +{ + ObjDetectionPipeline::PreProcessing(frame, processed); + if (m_executor->GetInputDataType() == armnn::DataType::Float32) + { + // [0, 255] => [-1.0, 1.0] + processed.convertTo(processed, CV_32FC3, 1 / 127.5, -1); + } +} + +YoloV3Tiny::YoloV3Tiny(std::unique_ptr<ArmnnNetworkExecutor> executor, + float NMSThreshold, float ClsThreshold, float ObjectThreshold) : + ObjDetectionPipeline(std::move(executor), + std::move(std::make_unique<YoloResultDecoder>(NMSThreshold, + ClsThreshold, + ObjectThreshold))) +{} + +void YoloV3Tiny::PreProcessing(const cv::Mat& frame, cv::Mat& processed) +{ + ObjDetectionPipeline::PreProcessing(frame, processed); + if (m_executor->GetInputDataType() == armnn::DataType::Float32) + { + processed.convertTo(processed, CV_32FC3); + } +} + +IPipelinePtr CreatePipeline(od::ODPipelineOptions& config) +{ + auto executor = std::make_unique<od::ArmnnNetworkExecutor>(config.m_ModelFilePath, config.m_backends); + + if (config.m_ModelName == "SSD_MOBILE") + { + float detectionThreshold = 0.6; + + return std::make_unique<od::MobileNetSSDv1>(std::move(executor), + detectionThreshold + ); + } + else if (config.m_ModelName == "YOLO_V3_TINY") + { + float NMSThreshold = 0.6f; + float ClsThreshold = 0.6f; + float ObjectThreshold = 0.6f; + return std::make_unique<od::YoloV3Tiny>(std::move(executor), + NMSThreshold, + ClsThreshold, + ObjectThreshold + ); + } + else + { + throw std::invalid_argument("Unknown Model name: " + config.m_ModelName + " supplied by user."); + } + +} +}// namespace od
\ No newline at end of file |