diff options
Diffstat (limited to 'source/application/main/include/Classifier.hpp')
-rw-r--r-- | source/application/main/include/Classifier.hpp | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/source/application/main/include/Classifier.hpp b/source/application/main/include/Classifier.hpp index 3ee3148..d899e8e 100644 --- a/source/application/main/include/Classifier.hpp +++ b/source/application/main/include/Classifier.hpp @@ -42,18 +42,33 @@ namespace app { * populated by this function. * @param[in] labels Labels vector to match classified classes. * @param[in] topNCount Number of top classifications to pick. Default is 1. + * @param[in] useSoftmax Whether Softmax normalisation should be applied to output. Default is false. * @return true if successful, false otherwise. **/ + virtual bool GetClassificationResults( TfLiteTensor* outputTensor, std::vector<ClassificationResult>& vecResults, - const std::vector <std::string>& labels, uint32_t topNCount); + const std::vector <std::string>& labels, uint32_t topNCount, + bool use_softmax = false); + + /** + * @brief Populate the elements of the Classification Result object. + * @param[in] topNSet Ordered set of top 5 output class scores and labels. + * @param[out] vecResults A vector of classification results. + * populated by this function. + * @param[in] labels Labels vector to match classified classes. + **/ + + void SetVectorResults( + std::set<std::pair<float, uint32_t>>& topNSet, + std::vector<ClassificationResult>& vecResults, + const std::vector <std::string>& labels); private: /** * @brief Utility function that gets the top N classification results from the * output vector. - * @tparam T value type * @param[in] tensor Inference output tensor from an NN model. * @param[out] vecResults A vector of classification results * populated by this function. @@ -61,8 +76,8 @@ namespace app { * @param[in] labels Labels vector to match classified classes. * @return true if successful, false otherwise. **/ - template<typename T> - bool GetTopNResults(TfLiteTensor* tensor, + + bool GetTopNResults(const std::vector<float>& tensor, std::vector<ClassificationResult>& vecResults, uint32_t topNCount, const std::vector <std::string>& labels); |