aboutsummaryrefslogtreecommitdiff
path: root/utils/GraphUtils.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/GraphUtils.h')
-rw-r--r--utils/GraphUtils.h63
1 files changed, 63 insertions, 0 deletions
diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h
index d7f24afdd8..131378e5bd 100644
--- a/utils/GraphUtils.h
+++ b/utils/GraphUtils.h
@@ -283,6 +283,36 @@ private:
size_t _positive_samples_top5;
};
+/** Detection output accessor class */
+class DetectionOutputAccessor final : public graph::ITensorAccessor
+{
+public:
+ /** Constructor
+ *
+ * @param[in] labels_path Path to labels text file.
+ * @param[in] imgs_tensor_shapes Network input images tensor shapes.
+ * @param[out] output_stream (Optional) Output stream
+ */
+ DetectionOutputAccessor(const std::string &labels_path, std::vector<TensorShape> &imgs_tensor_shapes, std::ostream &output_stream = std::cout);
+ /** Allow instances of this class to be move constructed */
+ DetectionOutputAccessor(DetectionOutputAccessor &&) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ DetectionOutputAccessor(const DetectionOutputAccessor &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ DetectionOutputAccessor &operator=(const DetectionOutputAccessor &) = delete;
+
+ // Inherited methods overriden:
+ bool access_tensor(ITensor &tensor) override;
+
+private:
+ template <typename T>
+ void access_predictions_tensor(ITensor &tensor);
+
+ std::vector<std::string> _labels;
+ std::vector<TensorShape> _tensor_shapes;
+ std::ostream &_output_stream;
+};
+
/** Result accessor class */
class TopNPredictionsAccessor final : public graph::ITensorAccessor
{
@@ -472,6 +502,39 @@ inline std::unique_ptr<graph::ITensorAccessor> get_output_accessor(const arm_com
return arm_compute::support::cpp14::make_unique<TopNPredictionsAccessor>(graph_parameters.labels, top_n, output_stream);
}
}
+/** Generates appropriate output accessor according to the specified graph parameters
+ *
+ * @note If the output accessor is requested to validate the graph then ValidationOutputAccessor is generated
+ * else if output_accessor_file is empty will generate a DummyAccessor else will generate a TopNPredictionsAccessor
+ *
+ * @param[in] graph_parameters Graph parameters
+ * @param[in] tensor_shapes Network input images tensor shapes.
+ * @param[in] is_validation (Optional) Validation flag (default = false)
+ * @param[out] output_stream (Optional) Output stream (default = std::cout)
+ *
+ * @return An appropriate tensor accessor
+ */
+inline std::unique_ptr<graph::ITensorAccessor> get_detection_output_accessor(const arm_compute::utils::CommonGraphParams &graph_parameters,
+ std::vector<TensorShape> tensor_shapes,
+ bool is_validation = false,
+ std::ostream &output_stream = std::cout)
+{
+ if(!graph_parameters.validation_file.empty())
+ {
+ return arm_compute::support::cpp14::make_unique<ValidationOutputAccessor>(graph_parameters.validation_file,
+ output_stream,
+ graph_parameters.validation_range_start,
+ graph_parameters.validation_range_end);
+ }
+ else if(graph_parameters.labels.empty())
+ {
+ return arm_compute::support::cpp14::make_unique<DummyAccessor>(0);
+ }
+ else
+ {
+ return arm_compute::support::cpp14::make_unique<DetectionOutputAccessor>(graph_parameters.labels, tensor_shapes, output_stream);
+ }
+}
/** Generates appropriate npy output accessor according to the specified npy_path
*
* @note If npy_path is empty will generate a DummyAccessor else will generate a NpyAccessor