diff options
Diffstat (limited to 'source/use_case/img_class/include/ImgClassProcessing.hpp')
-rw-r--r-- | source/use_case/img_class/include/ImgClassProcessing.hpp | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp index 59db4a5..e931b7d 100644 --- a/source/use_case/img_class/include/ImgClassProcessing.hpp +++ b/source/use_case/img_class/include/ImgClassProcessing.hpp @@ -34,9 +34,10 @@ namespace app { public: /** * @brief Constructor - * @param[in] model Pointer to the the Image classification Model object. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. + * @param[in] convertToInt8 Should the image be converted to Int8 range. **/ - explicit ImgClassPreProcess(Model* model); + explicit ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8); /** * @brief Should perform pre-processing of 'raw' input image data and load it into @@ -46,6 +47,10 @@ namespace app { * @return true if successful, false otherwise. **/ bool DoPreProcess(const void* input, size_t inputSize) override; + + private: + TfLiteTensor* m_inputTensor; + bool m_convertToInt8; }; /** @@ -55,29 +60,30 @@ namespace app { */ class ImgClassPostProcess : public BasePostProcess { - private: - Classifier& m_imgClassifier; - const std::vector<std::string>& m_labels; - std::vector<ClassificationResult>& m_results; - public: /** * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the the Image classification Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[in] results Vector of classification results to store decoded outputs. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[in] results Vector of classification results to store decoded outputs. **/ - ImgClassPostProcess(Classifier& classifier, Model* model, + ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results); /** - * @brief Should perform post-processing of the result of inference then populate + * @brief Should perform post-processing of the result of inference then * populate classification result data for any later use. * @return true if successful, false otherwise. **/ bool DoPostProcess() override; + + private: + TfLiteTensor* m_outputTensor; + Classifier& m_imgClassifier; + const std::vector<std::string>& m_labels; + std::vector<ClassificationResult>& m_results; }; } /* namespace app */ |