diff options
Diffstat (limited to 'utils/GraphUtils.h')
-rw-r--r-- | utils/GraphUtils.h | 45 |
1 files changed, 44 insertions, 1 deletions
diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h index 349d8558fc..768c608d26 100644 --- a/utils/GraphUtils.h +++ b/utils/GraphUtils.h @@ -185,7 +185,7 @@ public: * @param[in] start (Optional) Start range * @param[in] end (Optional) End range * - * @note + * @note Range is defined as [start, end] */ ValidationInputAccessor(const std::string &image_list, std::string images_path, @@ -203,6 +203,49 @@ private: size_t _offset; }; +/** Output Accessor used for network validation */ +class ValidationOutputAccessor final : public graph::ITensorAccessor +{ +public: + /** Default Constructor + * + * @param[in] image_list File containing all the images and labels results + * @param[in] top_n (Optional) Top N accuracy (Defaults to 5) + * @param[out] output_stream (Optional) Output stream (Defaults to the standard output stream) + * @param[in] start (Optional) Start range + * @param[in] end (Optional) End range + * + * @note Range is defined as [start, end] + */ + ValidationOutputAccessor(const std::string &image_list, + size_t top_n = 5, + std::ostream &output_stream = std::cout, + unsigned int start = 0, + unsigned int end = 0); + /** Reset accessor state */ + void reset(); + + // Inherited methods overriden: + bool access_tensor(ITensor &tensor) override; + +private: + /** Access predictions of the tensor + * + * @tparam T Tensor elements type + * + * @param[in] tensor Tensor to read the predictions from + */ + template <typename T> + std::vector<size_t> access_predictions_tensor(ITensor &tensor); + +private: + std::vector<int> _results; + std::ostream &_output_stream; + size_t _top_n; + size_t _offset; + size_t _positive_samples; +}; + /** Result accessor class */ class TopNPredictionsAccessor final : public graph::ITensorAccessor { |