summaryrefslogtreecommitdiff
path: root/source/use_case/img_class
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/img_class')
-rw-r--r--source/use_case/img_class/include/ImgClassProcessing.hpp23
-rw-r--r--source/use_case/img_class/src/ImgClassProcessing.cc8
-rw-r--r--source/use_case/img_class/src/UseCaseHandler.cc17
3 files changed, 39 insertions, 9 deletions
diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp
index 5a59b5f..59db4a5 100644
--- a/source/use_case/img_class/include/ImgClassProcessing.hpp
+++ b/source/use_case/img_class/include/ImgClassProcessing.hpp
@@ -32,8 +32,19 @@ namespace app {
class ImgClassPreProcess : public BasePreProcess {
public:
+ /**
+ * @brief Constructor
+ * @param[in] model Pointer to the the Image classification Model object.
+ **/
explicit ImgClassPreProcess(Model* model);
+ /**
+ * @brief Should perform pre-processing of 'raw' input image data and load it into
+ * TFLite Micro input tensors ready for inference
+ * @param[in] input Pointer to the data that pre-processing will work on.
+ * @param[in] inputSize Size of the input data.
+ * @return true if successful, false otherwise.
+ **/
bool DoPreProcess(const void* input, size_t inputSize) override;
};
@@ -50,10 +61,22 @@ namespace app {
std::vector<ClassificationResult>& m_results;
public:
+ /**
+ * @brief Constructor
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] model Pointer to the the Image classification Model object.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in] results Vector of classification results to store decoded outputs.
+ **/
ImgClassPostProcess(Classifier& classifier, Model* model,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results);
+ /**
+ * @brief Should perform post-processing of the result of inference then populate
+ * populate classification result data for any later use.
+ * @return true if successful, false otherwise.
+ **/
bool DoPostProcess() override;
};
diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc
index e33e3c1..6ba88ad 100644
--- a/source/use_case/img_class/src/ImgClassProcessing.cc
+++ b/source/use_case/img_class/src/ImgClassProcessing.cc
@@ -23,6 +23,9 @@ namespace app {
ImgClassPreProcess::ImgClassPreProcess(Model* model)
{
+ if (!model->IsInited()) {
+ printf_err("Model is not initialised!.\n");
+ }
this->m_model = model;
}
@@ -35,7 +38,7 @@ namespace app {
auto input = static_cast<const uint8_t*>(data);
TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0);
- memcpy(inputTensor->data.data, input, inputSize);
+ std::memcpy(inputTensor->data.data, input, inputSize);
debug("Input tensor populated \n");
if (this->m_model->IsDataSigned()) {
@@ -52,6 +55,9 @@ namespace app {
m_labels{labels},
m_results{results}
{
+ if (!model->IsInited()) {
+ printf_err("Model is not initialised!.\n");
+ }
this->m_model = model;
}
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc
index 98e2b59..11a1aa8 100644
--- a/source/use_case/img_class/src/UseCaseHandler.cc
+++ b/source/use_case/img_class/src/UseCaseHandler.cc
@@ -37,6 +37,12 @@ namespace app {
{
auto& profiler = ctx.Get<Profiler&>("profiler");
auto& model = ctx.Get<Model&>("model");
+ /* If the request has a valid size, set the image index as it might not be set. */
+ if (imgIndex < NUMBER_OF_FILES) {
+ if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) {
+ return false;
+ }
+ }
auto initialImIdx = ctx.Get<uint32_t>("imgIndex");
constexpr uint32_t dataPsnImgDownscaleFactor = 2;
@@ -46,12 +52,7 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartX = 150;
constexpr uint32_t dataPsnTxtInfStartY = 40;
- /* If the request has a valid size, set the image index. */
- if (imgIndex < NUMBER_OF_FILES) {
- if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) {
- return false;
- }
- }
+
if (!model.IsInited()) {
printf_err("Model is not initialised! Terminating processing.\n");
return false;
@@ -102,7 +103,7 @@ namespace app {
/* Display message on the LCD - inference running. */
hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
/* Select the image to run inference with. */
info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"),
@@ -129,7 +130,7 @@ namespace app {
/* Erase. */
str_inf = std::string(str_inf.size(), ' ');
hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
- dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
+ dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
/* Add results to context for access outside handler. */
ctx.Set<std::vector<ClassificationResult>>("results", results);