diff options
Diffstat (limited to 'utils/GraphUtils.h')
-rw-r--r-- | utils/GraphUtils.h | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h index d7f24afdd8..131378e5bd 100644 --- a/utils/GraphUtils.h +++ b/utils/GraphUtils.h @@ -283,6 +283,36 @@ private: size_t _positive_samples_top5; }; +/** Detection output accessor class */ +class DetectionOutputAccessor final : public graph::ITensorAccessor +{ +public: + /** Constructor + * + * @param[in] labels_path Path to labels text file. + * @param[in] imgs_tensor_shapes Network input images tensor shapes. + * @param[out] output_stream (Optional) Output stream + */ + DetectionOutputAccessor(const std::string &labels_path, std::vector<TensorShape> &imgs_tensor_shapes, std::ostream &output_stream = std::cout); + /** Allow instances of this class to be move constructed */ + DetectionOutputAccessor(DetectionOutputAccessor &&) = default; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + DetectionOutputAccessor(const DetectionOutputAccessor &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + DetectionOutputAccessor &operator=(const DetectionOutputAccessor &) = delete; + + // Inherited methods overriden: + bool access_tensor(ITensor &tensor) override; + +private: + template <typename T> + void access_predictions_tensor(ITensor &tensor); + + std::vector<std::string> _labels; + std::vector<TensorShape> _tensor_shapes; + std::ostream &_output_stream; +}; + /** Result accessor class */ class TopNPredictionsAccessor final : public graph::ITensorAccessor { @@ -472,6 +502,39 @@ inline std::unique_ptr<graph::ITensorAccessor> get_output_accessor(const arm_com return arm_compute::support::cpp14::make_unique<TopNPredictionsAccessor>(graph_parameters.labels, top_n, output_stream); } } +/** Generates appropriate output accessor according to the specified graph parameters + * + * @note If the output accessor is requested to validate the graph then ValidationOutputAccessor is generated + * else if output_accessor_file is empty will generate a DummyAccessor else will generate a TopNPredictionsAccessor + * + * @param[in] graph_parameters Graph parameters + * @param[in] tensor_shapes Network input images tensor shapes. + * @param[in] is_validation (Optional) Validation flag (default = false) + * @param[out] output_stream (Optional) Output stream (default = std::cout) + * + * @return An appropriate tensor accessor + */ +inline std::unique_ptr<graph::ITensorAccessor> get_detection_output_accessor(const arm_compute::utils::CommonGraphParams &graph_parameters, + std::vector<TensorShape> tensor_shapes, + bool is_validation = false, + std::ostream &output_stream = std::cout) +{ + if(!graph_parameters.validation_file.empty()) + { + return arm_compute::support::cpp14::make_unique<ValidationOutputAccessor>(graph_parameters.validation_file, + output_stream, + graph_parameters.validation_range_start, + graph_parameters.validation_range_end); + } + else if(graph_parameters.labels.empty()) + { + return arm_compute::support::cpp14::make_unique<DummyAccessor>(0); + } + else + { + return arm_compute::support::cpp14::make_unique<DetectionOutputAccessor>(graph_parameters.labels, tensor_shapes, output_stream); + } +} /** Generates appropriate npy output accessor according to the specified npy_path * * @note If npy_path is empty will generate a DummyAccessor else will generate a NpyAccessor |