summaryrefslogtreecommitdiff
path: root/source/use_case/img_class/include/ImgClassProcessing.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/img_class/include/ImgClassProcessing.hpp')
-rw-r--r--  source/use_case/img_class/include/ImgClassProcessing.hpp  32
1 file changed, 19 insertions, 13 deletions
diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp
index 59db4a5..e931b7d 100644
--- a/source/use_case/img_class/include/ImgClassProcessing.hpp
+++ b/source/use_case/img_class/include/ImgClassProcessing.hpp
@@ -34,9 +34,10 @@ namespace app {
public:
/**
* @brief Constructor
- * @param[in] model Pointer to the the Image classification Model object.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] convertToInt8 Should the image be converted to Int8 range.
**/
- explicit ImgClassPreProcess(Model* model);
+ explicit ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8);
/**
* @brief Should perform pre-processing of 'raw' input image data and load it into
@@ -46,6 +47,10 @@ namespace app {
* @return true if successful, false otherwise.
**/
bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ private:
+ TfLiteTensor* m_inputTensor;
+ bool m_convertToInt8;
};
/**
@@ -55,29 +60,30 @@ namespace app {
*/
class ImgClassPostProcess : public BasePostProcess {
- private:
- Classifier& m_imgClassifier;
- const std::vector<std::string>& m_labels;
- std::vector<ClassificationResult>& m_results;
-
public:
/**
* @brief Constructor
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] model Pointer to the the Image classification Model object.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in] results Vector of classification results to store decoded outputs.
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in] results Vector of classification results to store decoded outputs.
**/
- ImgClassPostProcess(Classifier& classifier, Model* model,
+ ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results);
/**
- * @brief Should perform post-processing of the result of inference then populate
+ * @brief Should perform post-processing of the result of inference then
* populate classification result data for any later use.
* @return true if successful, false otherwise.
**/
bool DoPostProcess() override;
+
+ private:
+ TfLiteTensor* m_outputTensor;
+ Classifier& m_imgClassifier;
+ const std::vector<std::string>& m_labels;
+ std::vector<ClassificationResult>& m_results;
};
} /* namespace app */