From a7acb3cbabeb66ce647684466a04c96b2963c9c9 Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Tue, 8 Jan 2019 13:48:44 +0000 Subject: COMPMID-1849: Implement CPPDetectionPostProcessLayer * Add DetectionPostProcessLayer * Add DetectionPostProcessLayer at the graph Change-Id: I7e56f6cffc26f112d26dfe74853085bb8ec7d849 Signed-off-by: Isabella Gottardi Reviewed-on: https://review.mlplatform.org/c/1639 Reviewed-by: Giuseppe Rossini Tested-by: Arm Jenkins --- arm_compute/core/Types.h | 116 +++++++++++++++++++ arm_compute/graph/GraphBuilder.h | 15 +++ arm_compute/graph/INodeVisitor.h | 9 ++ arm_compute/graph/TypePrinter.h | 3 + arm_compute/graph/Types.h | 2 + arm_compute/graph/backends/FunctionHelpers.h | 56 ++++++++++ arm_compute/graph/backends/ValidateHelpers.h | 27 +++++ arm_compute/graph/frontend/Layers.h | 33 ++++++ .../graph/nodes/DetectionPostProcessLayerNode.h | 62 +++++++++++ arm_compute/graph/nodes/Nodes.h | 1 + arm_compute/graph/nodes/NodesFwd.h | 1 + arm_compute/runtime/CPP/CPPFunctions.h | 1 + .../CPP/functions/CPPDetectionOutputLayer.h | 9 +- .../CPP/functions/CPPDetectionPostProcessLayer.h | 123 +++++++++++++++++++++ 14 files changed, 450 insertions(+), 8 deletions(-) create mode 100644 arm_compute/graph/nodes/DetectionPostProcessLayerNode.h create mode 100644 arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h (limited to 'arm_compute') diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 2c17f273a5..6df74e7b88 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -943,6 +944,11 @@ private: std::array _steps; }; +// Bounding Box [xmin, ymin, xmax, ymax] +using BBox = std::array; +// LabelBBox used for map label and bounding box +using LabelBBox = std::map>; + /** Available Detection Output code types */ enum class DetectionOutputLayerCodeType { @@ -1071,6 +1077,116 @@ private: int _num_loc_classes; }; +/** Detection Output layer info */ +class DetectionPostProcessLayerInfo final +{ +public: + /** Default Constructor */ + DetectionPostProcessLayerInfo() + : _max_detections(), + _max_classes_per_detection(), + _nms_score_threshold(), + _iou_threshold(), + _num_classes(), + _scales_values(), + _use_regular_nms(), + _detection_per_class() + { + } + /** Constructor + * + * @param[in] max_detections Number of total detection. + * @param[in] max_classes_per_detection Number of total classes to be kept after NMS step. Used in the Fast Non-Max-Suppression + * @param[in] nms_score_threshold Threshold to be used in NMS + * @param[in] iou_threshold Threshold to be used during the intersection over union. + * @param[in] num_classes Number of classes. + * @param[in] scales_values Scales values used for decode center size boxes. + * @param[in] use_regular_nms (Optional) Boolean to determinate if use regular or fast nms. + * @param[in] detection_per_class (Optional) Number of detection per class. Used in the Regular Non-Max-Suppression + */ + DetectionPostProcessLayerInfo(unsigned int max_detections, unsigned int max_classes_per_detection, float nms_score_threshold, float iou_threshold, unsigned int num_classes, + std::array scales_values, bool use_regular_nms = false, unsigned int detection_per_class = 100) + : _max_detections(max_detections), + _max_classes_per_detection(max_classes_per_detection), + _nms_score_threshold(nms_score_threshold), + _iou_threshold(iou_threshold), + _num_classes(num_classes), + _scales_values(scales_values), + _use_regular_nms(use_regular_nms), + _detection_per_class(detection_per_class) + { + } + /** Get max detections. */ + unsigned int max_detections() const + { + return _max_detections; + } + /** Get max_classes per detection. Used in the Fast Non-Max-Suppression.*/ + unsigned int max_classes_per_detection() const + { + return _max_classes_per_detection; + } + /** Get detection per class. Used in the Regular Non-Max-Suppression */ + unsigned int detection_per_class() const + { + return _detection_per_class; + } + /** Get nms threshold. */ + float nms_score_threshold() const + { + return _nms_score_threshold; + } + /** Get intersection over union threshold. */ + float iou_threshold() const + { + return _iou_threshold; + } + /** Get num classes. */ + unsigned int num_classes() const + { + return _num_classes; + } + /** Get if use regular nms. */ + bool use_regular_nms() const + { + return _use_regular_nms; + } + /** Get y scale value. */ + float scale_value_y() const + { + // Saved as [y,x,h,w] + return _scales_values[0]; + } + /** Get x scale value. */ + float scale_value_x() const + { + // Saved as [y,x,h,w] + return _scales_values[1]; + } + /** Get h scale value. */ + float scale_value_h() const + { + // Saved as [y,x,h,w] + return _scales_values[2]; + } + /** Get w scale value. */ + float scale_value_w() const + { + // Saved as [y,x,h,w] + return _scales_values[3]; + } + +private: + unsigned int _max_detections; + unsigned int _max_classes_per_detection; + float _nms_score_threshold; + float _iou_threshold; + unsigned int _num_classes; + std::array _scales_values; + bool _use_regular_nms; + unsigned int _detection_per_class; +}; + /** Pooling Layer Information class */ class PoolingLayerInfo { diff --git a/arm_compute/graph/GraphBuilder.h b/arm_compute/graph/GraphBuilder.h index e1049ca938..dc41ed5367 100644 --- a/arm_compute/graph/GraphBuilder.h +++ b/arm_compute/graph/GraphBuilder.h @@ -217,6 +217,21 @@ public: * @return Node ID of the created node, EmptyNodeID in case of error */ static NodeID add_detection_output_node(Graph &g, NodeParams params, NodeIdxPair input_loc, NodeIdxPair input_conf, NodeIdxPair input_priorbox, const DetectionOutputLayerInfo &detect_info); + /** Adds a detection post process layer node to the graph + * + * @param[in] g Graph to add the node to + * @param[in] params Common node parameters + * @param[in] input_box_encoding Boxes input to the detection output layer node as a NodeID-Index pair + * @param[in] input_class_prediction Class prediction input to the detection output layer node as a NodeID-Index pair + * @param[in] detect_info Detection output layer parameters + * @param[in] anchors_accessor (Optional) Const Node ID that contains the anchor values + * @param[in] anchor_quant_info (Optional) Anchor quantization info + * + * @return Node ID of the created node, EmptyNodeID in case of error + */ + static NodeID add_detection_post_process_node(Graph &g, NodeParams params, NodeIdxPair input_box_encoding, NodeIdxPair input_class_prediction, + const DetectionPostProcessLayerInfo &detect_info, ITensorAccessorUPtr anchors_accessor = nullptr, + const QuantizationInfo &anchor_quant_info = QuantizationInfo()); /** Adds a Dummy node to the graph * * @note this node if for debugging purposes. Just alters the shape of the graph pipeline as requested. diff --git a/arm_compute/graph/INodeVisitor.h b/arm_compute/graph/INodeVisitor.h index 5c5b777ac9..f97906d02a 100644 --- a/arm_compute/graph/INodeVisitor.h +++ b/arm_compute/graph/INodeVisitor.h @@ -76,6 +76,11 @@ public: * @param[in] n Node to visit. */ virtual void visit(DetectionOutputLayerNode &n) = 0; + /** Visit DetectionPostProcessLayerNode. + * + * @param[in] n Node to visit. + */ + virtual void visit(DetectionPostProcessLayerNode &n) = 0; /** Visit EltwiseLayerNode. * * @param[in] n Node to visit. @@ -199,6 +204,10 @@ public: { default_visit(); } + virtual void visit(DetectionPostProcessLayerNode &n) override + { + default_visit(); + } virtual void visit(DepthwiseConvolutionLayerNode &n) override { default_visit(); diff --git a/arm_compute/graph/TypePrinter.h b/arm_compute/graph/TypePrinter.h index 9da0e6157c..e4188125b9 100644 --- a/arm_compute/graph/TypePrinter.h +++ b/arm_compute/graph/TypePrinter.h @@ -86,6 +86,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const NodeType &node_type) case NodeType::DetectionOutputLayer: os << "DetectionOutputLayer"; break; + case NodeType::DetectionPostProcessLayer: + os << "DetectionPostProcessLayer"; + break; case NodeType::DepthwiseConvolutionLayer: os << "DepthwiseConvolutionLayer"; break; diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h index 9f962425b3..8b97708a63 100644 --- a/arm_compute/graph/Types.h +++ b/arm_compute/graph/Types.h @@ -48,6 +48,7 @@ using arm_compute::PermutationVector; using arm_compute::ActivationLayerInfo; using arm_compute::DetectionOutputLayerInfo; +using arm_compute::DetectionPostProcessLayerInfo; using arm_compute::NormType; using arm_compute::NormalizationLayerInfo; using arm_compute::FullyConnectedLayerInfo; @@ -137,6 +138,7 @@ enum class NodeType DeconvolutionLayer, DepthwiseConvolutionLayer, DetectionOutputLayer, + DetectionPostProcessLayer, EltwiseLayer, FlattenLayer, FullyConnectedLayer, diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h index ed5b35c0d1..dd833061a9 100644 --- a/arm_compute/graph/backends/FunctionHelpers.h +++ b/arm_compute/graph/backends/FunctionHelpers.h @@ -644,6 +644,62 @@ std::unique_ptr create_detection_output_layer(DetectionOutputLayerNod return std::move(func); } + +/** Create a backend detection post process layer function + * + * @tparam DetectionPostProcessLayerFunction Backend detection output function + * @tparam TargetInfo Target-specific information + * + * @param[in] node Node to create the backend function for + * + * @return Backend detection post process layer function + */ +template +std::unique_ptr create_detection_post_process_layer(DetectionPostProcessLayerNode &node) +{ + validate_node(node, 3 /* expected inputs */, 4 /* expected outputs */); + + // Extract IO and info + typename TargetInfo::TensorType *input0 = get_backing_tensor(node.input(0)); + typename TargetInfo::TensorType *input1 = get_backing_tensor(node.input(1)); + typename TargetInfo::TensorType *input2 = get_backing_tensor(node.input(2)); + typename TargetInfo::TensorType *output0 = get_backing_tensor(node.output(0)); + typename TargetInfo::TensorType *output1 = get_backing_tensor(node.output(1)); + typename TargetInfo::TensorType *output2 = get_backing_tensor(node.output(2)); + typename TargetInfo::TensorType *output3 = get_backing_tensor(node.output(3)); + const DetectionPostProcessLayerInfo detect_info = node.detection_post_process_info(); + + ARM_COMPUTE_ERROR_ON(input0 == nullptr); + ARM_COMPUTE_ERROR_ON(input1 == nullptr); + ARM_COMPUTE_ERROR_ON(input2 == nullptr); + ARM_COMPUTE_ERROR_ON(output0 == nullptr); + ARM_COMPUTE_ERROR_ON(output1 == nullptr); + ARM_COMPUTE_ERROR_ON(output2 == nullptr); + ARM_COMPUTE_ERROR_ON(output3 == nullptr); + + // Create and configure function + auto func = support::cpp14::make_unique(); + func->configure(input0, input1, input2, output0, output1, output2, output3, detect_info); + + // Log info + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " + << node.name() + << " Type: " << node.type() + << " Target: " << TargetInfo::TargetType + << " Data Type: " << input0->info()->data_type() + << " Input0 shape: " << input0->info()->tensor_shape() + << " Input1 shape: " << input1->info()->tensor_shape() + << " Input2 shape: " << input2->info()->tensor_shape() + << " Output0 shape: " << output0->info()->tensor_shape() + << " Output1 shape: " << output1->info()->tensor_shape() + << " Output2 shape: " << output2->info()->tensor_shape() + << " Output3 shape: " << output3->info()->tensor_shape() + << " DetectionPostProcessLayer info: " << detect_info + << std::endl); + + return std::move(func); +} + /** Create a backend element-wise operation layer function * * @tparam EltwiseFunctions Backend element-wise function diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h index 3a5686336b..13de273bdf 100644 --- a/arm_compute/graph/backends/ValidateHelpers.h +++ b/arm_compute/graph/backends/ValidateHelpers.h @@ -228,6 +228,33 @@ Status validate_detection_output_layer(DetectionOutputLayerNode &node) return DetectionOutputLayer::validate(input0, input1, input2, output, detect_info); } +/** Validates a detection post process layer node + * + * @tparam DetectionPostProcessLayer DetectionOutput layer type + * + * @param[in] node Node to validate + * + * @return Status + */ +template +Status validate_detection_post_process_layer(DetectionPostProcessLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionPostProcessLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 4); + + // Extract IO and info + arm_compute::ITensorInfo *input0 = get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *input1 = get_backing_tensor_info(node.input(1)); + arm_compute::ITensorInfo *input2 = get_backing_tensor_info(node.input(2)); + arm_compute::ITensorInfo *output0 = get_backing_tensor_info(node.output(0)); + arm_compute::ITensorInfo *output1 = get_backing_tensor_info(node.output(1)); + arm_compute::ITensorInfo *output2 = get_backing_tensor_info(node.output(2)); + arm_compute::ITensorInfo *output3 = get_backing_tensor_info(node.output(3)); + const DetectionPostProcessLayerInfo detect_info = node.detection_post_process_info(); + + return DetectionPostProcessLayer::validate(input0, input1, input2, output0, output1, output2, output3, detect_info); +} /** Validates a Generate Proposals layer node * diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index 3fc4af46d5..27a0cd3026 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -493,6 +493,39 @@ private: SubStream _ss_prior; DetectionOutputLayerInfo _detect_info; }; +/** DetectionOutputPostProcess Layer */ +class DetectionPostProcessLayer final : public ILayer +{ +public: + /** Construct a detection output layer. + * + * @param[in] sub_stream_class_prediction Class prediction graph sub-stream. + * @param[in] detect_info DetectionOutput parameters. + * @param[in] anchors Accessor to get anchors tensor data from. + * @param[in] out_quant_info (Optional) Output quantization info + */ + DetectionPostProcessLayer(SubStream &&sub_stream_class_prediction, DetectionPostProcessLayerInfo detect_info, ITensorAccessorUPtr anchors, + const QuantizationInfo out_quant_info = QuantizationInfo()) + : _sub_stream_class_prediction(std::move(sub_stream_class_prediction)), _detect_info(detect_info), _anchors(std::move(anchors)), _out_quant_info(std::move(out_quant_info)) + { + } + + NodeID create_layer(IStream &s) override + { + ARM_COMPUTE_ERROR_ON(_anchors == nullptr); + + NodeParams common_params = { name(), s.hints().target_hint }; + NodeIdxPair input_box_encoding = { s.tail_node(), 0 }; + NodeIdxPair input_class_prediction = { _sub_stream_class_prediction.tail_node(), 0 }; + return GraphBuilder::add_detection_post_process_node(s.graph(), common_params, input_box_encoding, input_class_prediction, _detect_info, std::move(_anchors), std::move(_out_quant_info)); + } + +private: + SubStream _sub_stream_class_prediction; + DetectionPostProcessLayerInfo _detect_info; + ITensorAccessorUPtr _anchors; + const QuantizationInfo _out_quant_info; +}; /** Dummy Layer */ class DummyLayer final : public ILayer { diff --git a/arm_compute/graph/nodes/DetectionPostProcessLayerNode.h b/arm_compute/graph/nodes/DetectionPostProcessLayerNode.h new file mode 100644 index 0000000000..76b1d8ce98 --- /dev/null +++ b/arm_compute/graph/nodes/DetectionPostProcessLayerNode.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH_DETECTION_POST_PROCESS_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH_DETECTION_POST_PROCESS_LAYER_NODE_H__ + +#include "arm_compute/graph/INode.h" + +namespace arm_compute +{ +namespace graph +{ +/** DetectionPostProcess Layer node */ +class DetectionPostProcessLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] detection_info DetectionPostProcess Layer information + */ + DetectionPostProcessLayerNode(DetectionPostProcessLayerInfo detection_info); + /** DetectionPostProcess metadata accessor + * + * @return DetectionPostProcess Layer info + */ + DetectionPostProcessLayerInfo detection_post_process_info() const; + + // Inherited overridden methods: + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + DetectionPostProcessLayerInfo _info; + + static const int kNumCoordBox = 4; + static const int kBatchSize = 1; +}; +} // namespace graph +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH_DETECTION_POST_PROCESS_LAYER_NODE_H__ */ \ No newline at end of file diff --git a/arm_compute/graph/nodes/Nodes.h b/arm_compute/graph/nodes/Nodes.h index 52e2f88528..1586270093 100644 --- a/arm_compute/graph/nodes/Nodes.h +++ b/arm_compute/graph/nodes/Nodes.h @@ -34,6 +34,7 @@ #include "arm_compute/graph/nodes/DeconvolutionLayerNode.h" #include "arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h" #include "arm_compute/graph/nodes/DetectionOutputLayerNode.h" +#include "arm_compute/graph/nodes/DetectionPostProcessLayerNode.h" #include "arm_compute/graph/nodes/DummyNode.h" #include "arm_compute/graph/nodes/EltwiseLayerNode.h" #include "arm_compute/graph/nodes/FlattenLayerNode.h" diff --git a/arm_compute/graph/nodes/NodesFwd.h b/arm_compute/graph/nodes/NodesFwd.h index 2c89679902..53f2a6a1b5 100644 --- a/arm_compute/graph/nodes/NodesFwd.h +++ b/arm_compute/graph/nodes/NodesFwd.h @@ -40,6 +40,7 @@ class ConvolutionLayerNode; class DeconvolutionLayerNode; class DepthwiseConvolutionLayerNode; class DetectionOutputLayerNode; +class DetectionPostProcessLayerNode; class DummyNode; class EltwiseLayerNode; class FlattenLayerNode; diff --git a/arm_compute/runtime/CPP/CPPFunctions.h b/arm_compute/runtime/CPP/CPPFunctions.h index 1dff03f349..743929fae8 100644 --- a/arm_compute/runtime/CPP/CPPFunctions.h +++ b/arm_compute/runtime/CPP/CPPFunctions.h @@ -27,6 +27,7 @@ /* Header regrouping all the CPP functions */ #include "arm_compute/runtime/CPP/functions/CPPBoxWithNonMaximaSuppressionLimit.h" #include "arm_compute/runtime/CPP/functions/CPPDetectionOutputLayer.h" +#include "arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h" #include "arm_compute/runtime/CPP/functions/CPPNonMaximumSuppression.h" #include "arm_compute/runtime/CPP/functions/CPPPermute.h" #include "arm_compute/runtime/CPP/functions/CPPTopKV.h" diff --git a/arm_compute/runtime/CPP/functions/CPPDetectionOutputLayer.h b/arm_compute/runtime/CPP/functions/CPPDetectionOutputLayer.h index 71be8a0ad8..4e1b8f2a74 100644 --- a/arm_compute/runtime/CPP/functions/CPPDetectionOutputLayer.h +++ b/arm_compute/runtime/CPP/functions/CPPDetectionOutputLayer.h @@ -28,17 +28,10 @@ #include "arm_compute/core/Types.h" -#include - namespace arm_compute { class ITensor; -// Normalized Bounding Box [xmin, ymin, xmax, ymax] -using NormalizedBBox = std::array; -// LabelBBox used for map label and bounding box -using LabelBBox = std::map>; - /** CPP Function to generate the detection output based on location and confidence * predictions by doing non maximum suppression. * @@ -91,7 +84,7 @@ private: std::vector _all_location_predictions; std::vector>> _all_confidence_scores; - std::vector _all_prior_bboxes; + std::vector _all_prior_bboxes; std::vector> _all_prior_variances; std::vector _all_decode_bboxes; std::vector>> _all_indices; diff --git a/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h b/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h new file mode 100644 index 0000000000..c13def67c7 --- /dev/null +++ b/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_CPP_DETECTION_POSTPROCESS_H__ +#define __ARM_COMPUTE_CPP_DETECTION_POSTPROCESS_H__ + +#include "arm_compute/runtime/CPP/ICPPSimpleFunction.h" + +#include "arm_compute/core/Types.h" +#include "arm_compute/runtime/CPP/functions/CPPNonMaximumSuppression.h" +#include "arm_compute/runtime/IMemoryManager.h" +#include "arm_compute/runtime/MemoryGroup.h" +#include "arm_compute/runtime/Tensor.h" + +#include + +namespace arm_compute +{ +class ITensor; + +/** CPP Function to generate the detection output based on center size encoded boxes, class prediction and anchors + * by doing non maximum suppression. + * + * @note Intended for use with MultiBox detection method. + */ +class CPPDetectionPostProcessLayer : public IFunction +{ +public: + /** Constructor */ + CPPDetectionPostProcessLayer(std::shared_ptr memory_manager = nullptr); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CPPDetectionPostProcessLayer(const CPPDetectionPostProcessLayer &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CPPDetectionPostProcessLayer &operator=(const CPPDetectionPostProcessLayer &) = delete; + /** Configure the detection output layer CPP function + * + * @param[in] input_box_encoding The bounding box input tensor. Data types supported: F32, QASYMM8. + * @param[in] input_score The class prediction input tensor. Data types supported: Same as @p input_box_encoding. + * @param[in] input_anchors The anchors input tensor. Data types supported: Same as @p input_box_encoding. + * @param[out] output_boxes The boxes output tensor. Data types supported: F32. + * @param[out] output_classes The classes output tensor. Data types supported: Same as @p output_boxes. + * @param[out] output_scores The scores output tensor. Data types supported: Same as @p output_boxes. + * @param[out] num_detection The number of output detection. Data types supported: Same as @p output_boxes. + * @param[in] info (Optional) DetectionPostProcessLayerInfo information. + * + * @note Output contains all the detections. Of those, only the ones selected by the valid region are valid. + */ + void configure(const ITensor *input_box_encoding, const ITensor *input_score, const ITensor *input_anchors, + ITensor *output_boxes, ITensor *output_classes, ITensor *output_scores, ITensor *num_detection, DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo()); + /** Static function to check if given info will lead to a valid configuration of @ref CPPDetectionPostProcessLayer + * + * @param[in] input_box_encoding The bounding box input tensor info. Data types supported: F32, QASYMM8. + * @param[in] input_class_score The class prediction input tensor info. Data types supported: F32, QASYMM8. + * @param[in] input_anchors The anchors input tensor. Data types supported: F32, QASYMM8. + * @param[out] output_boxes The output tensor. Data types supported: F32. + * @param[out] output_classes The output tensor. Data types supported: Same as @p output_boxes. + * @param[out] output_scores The output tensor. Data types supported: Same as @p output_boxes. + * @param[out] num_detection The number of output detection. Data types supported: Same as @p output_boxes. + * @param[in] info (Optional) DetectionPostProcessLayerInfo information. + * + * @return a status + */ + static Status validate(const ITensorInfo *input_box_encoding, const ITensorInfo *input_class_score, const ITensorInfo *input_anchors, + ITensorInfo *output_boxes, ITensorInfo *output_classes, ITensorInfo *output_scores, ITensorInfo *num_detection, + DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo()); + // Inherited methods overridden: + void run() override; + +private: + MemoryGroup _memory_group; + CPPNonMaximumSuppression _nms; + const ITensor *_input_box_encoding; + const ITensor *_input_scores; + const ITensor *_input_anchors; + ITensor *_output_boxes; + ITensor *_output_classes; + ITensor *_output_scores; + ITensor *_num_detection; + DetectionPostProcessLayerInfo _info; + + const unsigned int _kBatchSize = 1; + const unsigned int _kNumCoordBox = 4; + unsigned int _num_boxes; + unsigned int _num_classes_with_background; + unsigned int _num_max_detected_boxes; + + Tensor _decoded_boxes; + Tensor _decoded_scores; + Tensor _selected_indices; + Tensor _class_scores; + const ITensor *_input_scores_to_use; + + // Intermediate results + std::vector _result_idx_boxes_after_nms; + std::vector _result_classes_after_nms; + std::vector _result_scores_after_nms; + std::vector _sorted_indices; + + // Temporary values + std::vector _box_scores; +}; +} // namespace arm_compute +#endif /* __ARM_COMPUTE_CPP_DETECTION_POSTPROCESS_H__ */ -- cgit v1.2.1