aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2018-11-27 08:51:10 +0000
committerIsabella Gottardi <isabella.gottardi@arm.com>2018-12-13 11:21:59 +0000
commit7234ed8c3d07c76963eb3bce9530994421ad7e67 (patch)
tree6834d5fc3cc23eb47bcfad3a4191d91c87c8f9e0 /examples
parent0e7210de821a7d1164017b8b9e11b53805185b25 (diff)
downloadComputeLibrary-7234ed8c3d07c76963eb3bce9530994421ad7e67.tar.gz
COMPMID-1808: Add Detection Output Layer to the GraphAPI
COMPMID-1710: Integrate Detection ouput in MobilenetSSD graph example Change-Id: I384d1eb492ef14ece58f2023ad7bbc16f834450b Reviewed-on: https://review.mlplatform.org/356 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez <pablo.tello@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'examples')
-rw-r--r--examples/graph_ssd_mobilenet.cpp37
1 files changed, 19 insertions, 18 deletions
diff --git a/examples/graph_ssd_mobilenet.cpp b/examples/graph_ssd_mobilenet.cpp
index 95a4dcc66b..676c5e9167 100644
--- a/examples/graph_ssd_mobilenet.cpp
+++ b/examples/graph_ssd_mobilenet.cpp
@@ -39,12 +39,9 @@ public:
GraphSSDMobilenetExample()
: cmd_parser(), common_opts(cmd_parser), common_params(), graph(0, "MobileNetSSD")
{
- mbox_loc_opt = cmd_parser.add_option<SimpleOption<std::string>>("mbox_loc_opt", "");
- mbox_loc_opt->set_help("Filename containing the reference values for the graph branch mbox_loc_opt.");
- mbox_conf_flatten_opt = cmd_parser.add_option<SimpleOption<std::string>>("mbox_conf_flatten", "");
- mbox_conf_flatten_opt->set_help("Filename containing the reference values for the graph branch mbox_conf_flatten.");
- mbox_priorbox_opt = cmd_parser.add_option<SimpleOption<std::string>>("mbox_priorbox", "");
- mbox_priorbox_opt->set_help("Filename containing the reference values for the graph branch mbox_priorbox.");
+ // Add topk option
+ keep_topk_opt = cmd_parser.add_option<SimpleOption<int>>("topk", 100);
+ keep_topk_opt->set_help("Top k detections results per image.");
}
GraphSSDMobilenetExample(const GraphSSDMobilenetExample &) = delete;
GraphSSDMobilenetExample &operator=(const GraphSSDMobilenetExample &) = delete;
@@ -162,8 +159,6 @@ public:
mbox_loc << ConcatLayer(std::move(conv_11_mbox_loc), std::move(conv_13_mbox_loc), conv_14_2_mbox_loc, std::move(conv_15_2_mbox_loc),
std::move(conv_16_2_mbox_loc), std::move(conv_17_2_mbox_loc));
- mbox_loc << OutputLayer(get_npy_output_accessor(mbox_loc_opt->value(), TensorShape(7668U), DataType::F32));
-
//mbox_conf
SubStream conv_11_mbox_conf(conv_11);
conv_11_mbox_conf << get_node_C(conv_11, data_path, "conv11_mbox_conf", 63, PadStrideInfo(1, 1, 0, 0));
@@ -190,8 +185,6 @@ public:
mbox_conf << SoftmaxLayer().set_name("mbox_conf/softmax");
mbox_conf << FlattenLayer().set_name("mbox_conf/flat");
- mbox_conf << OutputLayer(get_npy_output_accessor(mbox_conf_flatten_opt->value(), TensorShape(40257U), DataType::F32));
-
const std::vector<float> priorbox_variances = { 0.1f, 0.1f, 0.2f, 0.2f };
const float priorbox_offset = 0.5f;
const std::vector<float> priorbox_aspect_ratios = { 2.f, 3.f };
@@ -235,7 +228,19 @@ public:
std::move(conv_11_mbox_priorbox), std::move(conv_13_mbox_priorbox), std::move(conv_14_2_mbox_priorbox),
std::move(conv_15_2_mbox_priorbox), std::move(conv_16_2_mbox_priorbox), std::move(conv_17_2_mbox_priorbox));
- mbox_priorbox << OutputLayer(get_npy_output_accessor(mbox_priorbox_opt->value(), TensorShape(7668U, 2U, 1U), DataType::F32));
+ const int num_classes = 21;
+ const bool share_location = true;
+ const DetectionOutputLayerCodeType detection_type = DetectionOutputLayerCodeType::CENTER_SIZE;
+ const int keep_top_k = keep_topk_opt->value();
+ const float nms_threshold = 0.45f;
+ const int label_id_background = 0;
+ const float conf_thrs = 0.25f;
+ const int top_k = 100;
+
+ SubStream detection_ouput(mbox_loc);
+ detection_ouput << DetectionOutputLayer(std::move(mbox_conf), std::move(mbox_priorbox),
+ DetectionOutputLayerInfo(num_classes, share_location, detection_type, keep_top_k, nms_threshold, top_k, label_id_background, conf_thrs));
+ detection_ouput << OutputLayer(get_detection_output_accessor(common_params, { tensor_shape }));
// Finalize graph
GraphConfig config;
@@ -256,13 +261,9 @@ public:
private:
CommandLineParser cmd_parser;
CommonGraphOptions common_opts;
-
- SimpleOption<std::string> *mbox_loc_opt{ nullptr };
- SimpleOption<std::string> *mbox_conf_flatten_opt{ nullptr };
- SimpleOption<std::string> *mbox_priorbox_opt{ nullptr };
-
- CommonGraphParams common_params;
- Stream graph;
+ SimpleOption<int> *keep_topk_opt{ nullptr };
+ CommonGraphParams common_params;
+ Stream graph;
ConcatLayer get_node_A(IStream &master_graph, const std::string &data_path, std::string &&param_path,
unsigned int conv_filt,