summaryrefslogtreecommitdiff
path: root/source/use_case/img_class
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
commitb40ecf8522052809d2351677a96195d69e4d0c16 (patch)
tree8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/img_class
parentc291144b7f08c21d08cdaf79cc64dc420ca70070 (diff)
downloadml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments Removed the inference runner code as it wasn't used Fixes to doc strings Consistent naming e.g. Asr/Kws instead of ASR/KWS Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'source/use_case/img_class')
-rw-r--r--source/use_case/img_class/include/ImgClassProcessing.hpp32
-rw-r--r--source/use_case/img_class/src/ImgClassProcessing.cc33
-rw-r--r--source/use_case/img_class/src/UseCaseHandler.cc22
3 files changed, 43 insertions, 44 deletions
diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp
index 59db4a5..e931b7d 100644
--- a/source/use_case/img_class/include/ImgClassProcessing.hpp
+++ b/source/use_case/img_class/include/ImgClassProcessing.hpp
@@ -34,9 +34,10 @@ namespace app {
public:
/**
* @brief Constructor
- * @param[in] model Pointer to the the Image classification Model object.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
+ * @param[in] convertToInt8 Should the image be converted to Int8 range.
**/
- explicit ImgClassPreProcess(Model* model);
+ explicit ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8);
/**
* @brief Should perform pre-processing of 'raw' input image data and load it into
@@ -46,6 +47,10 @@ namespace app {
* @return true if successful, false otherwise.
**/
bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ private:
+ TfLiteTensor* m_inputTensor;
+ bool m_convertToInt8;
};
/**
@@ -55,29 +60,30 @@ namespace app {
*/
class ImgClassPostProcess : public BasePostProcess {
- private:
- Classifier& m_imgClassifier;
- const std::vector<std::string>& m_labels;
- std::vector<ClassificationResult>& m_results;
-
public:
/**
* @brief Constructor
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] model Pointer to the the Image classification Model object.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[in] results Vector of classification results to store decoded outputs.
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[in] results Vector of classification results to store decoded outputs.
**/
- ImgClassPostProcess(Classifier& classifier, Model* model,
+ ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results);
/**
- * @brief Should perform post-processing of the result of inference then populate
+ * @brief Should perform post-processing of the result of inference then
* populate classification result data for any later use.
* @return true if successful, false otherwise.
**/
bool DoPostProcess() override;
+
+ private:
+ TfLiteTensor* m_outputTensor;
+ Classifier& m_imgClassifier;
+ const std::vector<std::string>& m_labels;
+ std::vector<ClassificationResult>& m_results;
};
} /* namespace app */
diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc
index 6ba88ad..adf9794 100644
--- a/source/use_case/img_class/src/ImgClassProcessing.cc
+++ b/source/use_case/img_class/src/ImgClassProcessing.cc
@@ -21,50 +21,43 @@
namespace arm {
namespace app {
- ImgClassPreProcess::ImgClassPreProcess(Model* model)
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ ImgClassPreProcess::ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8)
+ :m_inputTensor{inputTensor},
+ m_convertToInt8{convertToInt8}
+ {}
bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize)
{
if (data == nullptr) {
printf_err("Data pointer is null");
+ return false;
}
auto input = static_cast<const uint8_t*>(data);
- TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0);
- std::memcpy(inputTensor->data.data, input, inputSize);
+ std::memcpy(this->m_inputTensor->data.data, input, inputSize);
debug("Input tensor populated \n");
- if (this->m_model->IsDataSigned()) {
- image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
+ if (this->m_convertToInt8) {
+ image::ConvertImgToInt8(this->m_inputTensor->data.data, this->m_inputTensor->bytes);
}
return true;
}
- ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model,
+ ImgClassPostProcess::ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results)
- :m_imgClassifier{classifier},
+ :m_outputTensor{outputTensor},
+ m_imgClassifier{classifier},
m_labels{labels},
m_results{results}
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ {}
bool ImgClassPostProcess::DoPostProcess()
{
return this->m_imgClassifier.GetClassificationResults(
- this->m_model->GetOutputTensor(0), this->m_results,
+ this->m_outputTensor, this->m_results,
this->m_labels, 5, false);
}
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc
index c68d816..5cc3959 100644
--- a/source/use_case/img_class/src/UseCaseHandler.cc
+++ b/source/use_case/img_class/src/UseCaseHandler.cc
@@ -59,6 +59,7 @@ namespace app {
}
TfLiteTensor* inputTensor = model.GetInputTensor(0);
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return false;
@@ -74,13 +75,12 @@ namespace app {
const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx];
/* Set up pre and post-processing. */
- ImgClassPreProcess preprocess = ImgClassPreProcess(&model);
+ ImgClassPreProcess preProcess = ImgClassPreProcess(inputTensor, model.IsDataSigned());
std::vector<ClassificationResult> results;
- ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model,
- ctx.Get<std::vector<std::string>&>("labels"), results);
-
- UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
+ ImgClassPostProcess postProcess = ImgClassPostProcess(outputTensor,
+ ctx.Get<ImgClassClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"),
+ results);
do {
hal_lcd_clear(COLOR_BLACK);
@@ -113,17 +113,18 @@ namespace app {
inputTensor->bytes : IMAGE_DATA_SIZE;
/* Run the pre-processing, inference and post-processing. */
- if (!runner.PreProcess(imgSrc, imgSz)) {
+ if (!preProcess.DoPreProcess(imgSrc, imgSz)) {
+ printf_err("Pre-processing failed.");
return false;
}
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -136,7 +137,6 @@ namespace app {
ctx.Set<std::vector<ClassificationResult>>("results", results);
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */