Diffstat (limited to 'source/use_case/img_class')
-rw-r--r--  source/use_case/img_class/include/ImgClassProcessing.hpp | 32
-rw-r--r--  source/use_case/img_class/src/ImgClassProcessing.cc      | 33
-rw-r--r--  source/use_case/img_class/src/UseCaseHandler.cc          | 22
3 files changed, 43 insertions(+), 44 deletions(-)
diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp
index 59db4a5..e931b7d 100644
--- a/source/use_case/img_class/include/ImgClassProcessing.hpp
+++ b/source/use_case/img_class/include/ImgClassProcessing.hpp
@@ -34,9 +34,10 @@ namespace app {
public:
/**
* @brief Constructor
- * @param[in] model Pointer to the the Image classification Model object.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] convertToInt8 Should the image be converted to Int8 range.
**/
- explicit ImgClassPreProcess(Model* model);
+ explicit ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8);
/**
* @brief Should perform pre-processing of 'raw' input image data and load it into
@@ -46,6 +47,10 @@ namespace app {
* @return true if successful, false otherwise.
**/
bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ private:
+ TfLiteTensor* m_inputTensor;
+ bool m_convertToInt8;
};
/**
@@ -55,29 +60,30 @@ namespace app {
*/
class ImgClassPostProcess : public BasePostProcess {
- private:
- Classifier& m_imgClassifier;
- const std::vector<std::string>& m_labels;
- std::vector<ClassificationResult>& m_results;
-
public:
/**
* @brief Constructor
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] model Pointer to the the Image classification Model object.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in] results Vector of classification results to store decoded outputs.
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in] results Vector of classification results to store decoded outputs.
**/
- ImgClassPostProcess(Classifier& classifier, Model* model,
+ ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results);
/**
- * @brief Should perform post-processing of the result of inference then populate
+ * @brief Should perform post-processing of the result of inference then
* populate classification result data for any later use.
* @return true if successful, false otherwise.
**/
bool DoPostProcess() override;
+
+ private:
+ TfLiteTensor* m_outputTensor;
+ Classifier& m_imgClassifier;
+ const std::vector<std::string>& m_labels;
+ std::vector<ClassificationResult>& m_results;
};
} /* namespace app */
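
Note: the header change above decouples both classes from the shared Model pointer; each now takes only the tensor(s) and collaborators it actually uses. A minimal wiring sketch, assuming an initialised arm::app::Model plus the classifier, labels and results objects that the use-case context normally provides (names here are illustrative, mirroring the UseCaseHandler.cc changes further down):

    TfLiteTensor* inputTensor  = model.GetInputTensor(0);
    TfLiteTensor* outputTensor = model.GetOutputTensor(0);

    /* Pre-processing needs only the input tensor and whether uint8 image
     * data must be shifted into int8 range for a signed-input model. */
    arm::app::ImgClassPreProcess preProcess(inputTensor, model.IsDataSigned());

    /* Post-processing needs only the output tensor, the classifier, the
     * labels and a results vector to populate. */
    std::vector<arm::app::ClassificationResult> results;
    arm::app::ImgClassPostProcess postProcess(outputTensor, classifier,
                                              labels, results);
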
diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc
index 6ba88ad..adf9794 100644
--- a/source/use_case/img_class/src/ImgClassProcessing.cc
+++ b/source/use_case/img_class/src/ImgClassProcessing.cc
@@ -21,50 +21,43 @@
namespace arm {
namespace app {
- ImgClassPreProcess::ImgClassPreProcess(Model* model)
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ ImgClassPreProcess::ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8)
+ :m_inputTensor{inputTensor},
+ m_convertToInt8{convertToInt8}
+ {}
bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize)
{
if (data == nullptr) {
printf_err("Data pointer is null");
+ return false;
}
auto input = static_cast<const uint8_t*>(data);
- TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0);
- std::memcpy(inputTensor->data.data, input, inputSize);
+ std::memcpy(this->m_inputTensor->data.data, input, inputSize);
debug("Input tensor populated \n");
- if (this->m_model->IsDataSigned()) {
- image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
+ if (this->m_convertToInt8) {
+ image::ConvertImgToInt8(this->m_inputTensor->data.data, this->m_inputTensor->bytes);
}
return true;
}
- ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model,
+ ImgClassPostProcess::ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results)
- :m_imgClassifier{classifier},
+ :m_outputTensor{outputTensor},
+ m_imgClassifier{classifier},
m_labels{labels},
m_results{results}
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ {}
bool ImgClassPostProcess::DoPostProcess()
{
return this->m_imgClassifier.GetClassificationResults(
- this->m_model->GetOutputTensor(0), this->m_results,
+ this->m_outputTensor, this->m_results,
this->m_labels, 5, false);
}
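
The pre-processing above now keys the int8 conversion on the constructor flag rather than querying the model. For reference, a hypothetical stand-in for image::ConvertImgToInt8 (the real helper lives in the shared image utilities; this sketch only illustrates the usual uint8-to-int8 offset and may not match it exactly):

    #include <cstddef>
    #include <cstdint>

    /* Hypothetical sketch: shift uint8 pixels [0, 255] into the int8 range
     * [-128, 127] in place, as expected by a model quantised with a signed
     * input tensor. Not the actual image::ConvertImgToInt8 implementation. */
    static void ConvertImgToInt8Sketch(void* data, size_t numBytes)
    {
        auto* bytes = static_cast<uint8_t*>(data);
        for (size_t i = 0; i < numBytes; ++i) {
            bytes[i] = static_cast<uint8_t>(static_cast<int32_t>(bytes[i]) - 128);
        }
    }
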
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc
index c68d816..5cc3959 100644
--- a/source/use_case/img_class/src/UseCaseHandler.cc
+++ b/source/use_case/img_class/src/UseCaseHandler.cc
@@ -59,6 +59,7 @@ namespace app {
}
TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return false;
@@ -74,13 +75,12 @@ namespace app {
const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
/* Set up pre and post-processing. */
- ImgClassPreProcess preprocess = ImgClassPreProcess(&model);
+ ImgClassPreProcess preProcess = ImgClassPreProcess(inputTensor, model.IsDataSigned());
std::vector<ClassificationResult> results;
- ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model,
- ctx.Get<std::vector<std::string>&>("labels"), results);
-
- UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
+ ImgClassPostProcess postProcess = ImgClassPostProcess(outputTensor,
+ ctx.Get<ImgClassClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"),
+ results);
do {
hal_lcd_clear(COLOR_BLACK);
@@ -113,17 +113,18 @@ namespace app {
inputTensor->bytes : IMAGE_DATA_SIZE;
/* Run the pre-processing, inference and post-processing. */
- if (!runner.PreProcess(imgSrc, imgSz)) {
+ if (!preProcess.DoPreProcess(imgSrc, imgSz)) {
+ printf_err("Pre-processing failed.");
return false;
}
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -136,7 +137,6 @@ namespace app {
ctx.Set<std::vector<ClassificationResult>>("results", results);
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */
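
The handler now calls a shared RunInference(model, profiler) helper instead of bracketing runner.RunInference() with explicit profiler calls. Judging from the StartProfiling("Inference") / StopProfiling pair removed above, such a helper plausibly looks like the following sketch (the real utility is defined elsewhere in the tree and may differ):

    /* Hypothetical sketch of the shared inference helper: profile exactly
     * the inference call, matching the behaviour of the removed explicit
     * StartProfiling("Inference") / StopProfiling pair. */
    static bool RunInferenceSketch(arm::app::Model& model,
                                   arm::app::Profiler& profiler)
    {
        profiler.StartProfiling("Inference");
        const bool success = model.RunInference();
        profiler.StopProfiling();
        return success;
    }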