summaryrefslogtreecommitdiff
path: root/source/use_case/vww
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
commitb40ecf8522052809d2351677a96195d69e4d0c16 (patch)
tree8647dfdae7bcae0ec6d9564ba7a971819fdda431 /source/use_case/vww
parentc291144b7f08c21d08cdaf79cc64dc420ca70070 (diff)
downloadml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments Removed the inference runner code as it wasn't used Fixes to doc strings Consistent naming e.g. Asr/Kws instead of ASR/KWS Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'source/use_case/vww')
-rw-r--r--source/use_case/vww/include/VisualWakeWordProcessing.hpp25
-rw-r--r--source/use_case/vww/src/UseCaseHandler.cc22
-rw-r--r--source/use_case/vww/src/VisualWakeWordProcessing.cc33
3 files changed, 37 insertions, 43 deletions
diff --git a/source/use_case/vww/include/VisualWakeWordProcessing.hpp b/source/use_case/vww/include/VisualWakeWordProcessing.hpp
index b1d68ce..bef161f 100644
--- a/source/use_case/vww/include/VisualWakeWordProcessing.hpp
+++ b/source/use_case/vww/include/VisualWakeWordProcessing.hpp
@@ -34,9 +34,9 @@ namespace app {
public:
/**
* @brief Constructor
- * @param[in] model Pointer to the the Image classification Model object.
+ * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
**/
- explicit VisualWakeWordPreProcess(Model* model);
+ explicit VisualWakeWordPreProcess(TfLiteTensor* inputTensor);
/**
* @brief Should perform pre-processing of 'raw' input image data and load it into
@@ -46,6 +46,9 @@ namespace app {
* @return true if successful, false otherwise.
**/
bool DoPreProcess(const void* input, size_t inputSize) override;
+
+ private:
+ TfLiteTensor* m_inputTensor;
};
/**
@@ -56,6 +59,7 @@ namespace app {
class VisualWakeWordPostProcess : public BasePostProcess {
private:
+ TfLiteTensor* m_outputTensor;
Classifier& m_vwwClassifier;
const std::vector<std::string>& m_labels;
std::vector<ClassificationResult>& m_results;
@@ -63,19 +67,20 @@ namespace app {
public:
/**
* @brief Constructor
- * @param[in] classifier Classifier object used to get top N results from classification.
- * @param[in] model Pointer to the VWW classification Model object.
- * @param[in] labels Vector of string labels to identify each output of the model.
- * @param[out] results Vector of classification results to store decoded outputs.
+ * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
+ * @param[in] classifier Classifier object used to get top N results from classification.
+ * @param[in] model Pointer to the VWW classification Model object.
+ * @param[in] labels Vector of string labels to identify each output of the model.
+ * @param[out] results Vector of classification results to store decoded outputs.
**/
- VisualWakeWordPostProcess(Classifier& classifier, Model* model,
+ VisualWakeWordPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results);
/**
- * @brief Should perform post-processing of the result of inference then
- * populate classification result data for any later use.
- * @return true if successful, false otherwise.
+ * @brief Should perform post-processing of the result of inference then
+ * populate classification result data for any later use.
+ * @return true if successful, false otherwise.
**/
bool DoPostProcess() override;
};
diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc
index 7681f89..267e6c4 100644
--- a/source/use_case/vww/src/UseCaseHandler.cc
+++ b/source/use_case/vww/src/UseCaseHandler.cc
@@ -53,7 +53,7 @@ namespace app {
}
TfLiteTensor* inputTensor = model.GetInputTensor(0);
-
+ TfLiteTensor* outputTensor = model.GetOutputTensor(0);
if (!inputTensor->dims) {
printf_err("Invalid input tensor dims\n");
return false;
@@ -75,15 +75,13 @@ namespace app {
const uint32_t displayChannels = 3;
/* Set up pre and post-processing. */
- VisualWakeWordPreProcess preprocess = VisualWakeWordPreProcess(&model);
+ VisualWakeWordPreProcess preProcess = VisualWakeWordPreProcess(inputTensor);
std::vector<ClassificationResult> results;
- VisualWakeWordPostProcess postprocess = VisualWakeWordPostProcess(
- ctx.Get<Classifier&>("classifier"), &model,
+ VisualWakeWordPostProcess postProcess = VisualWakeWordPostProcess(outputTensor,
+ ctx.Get<Classifier&>("classifier"),
ctx.Get<std::vector<std::string>&>("labels"), results);
- UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model);
-
do {
hal_lcd_clear(COLOR_BLACK);
@@ -115,17 +113,18 @@ namespace app {
inputTensor->bytes : IMAGE_DATA_SIZE;
/* Run the pre-processing, inference and post-processing. */
- if (!runner.PreProcess(imgSrc, imgSz)) {
+ if (!preProcess.DoPreProcess(imgSrc, imgSz)) {
+ printf_err("Pre-processing failed.");
return false;
}
- profiler.StartProfiling("Inference");
- if (!runner.RunInference()) {
+ if (!RunInference(model, profiler)) {
+ printf_err("Inference failed.");
return false;
}
- profiler.StopProfiling();
- if (!runner.PostProcess()) {
+ if (!postProcess.DoPostProcess()) {
+ printf_err("Post-processing failed.");
return false;
}
@@ -138,7 +137,6 @@ namespace app {
ctx.Set<std::vector<ClassificationResult>>("results", results);
#if VERIFY_TEST_OUTPUT
- TfLiteTensor* outputTensor = model.GetOutputTensor(0);
arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */
diff --git a/source/use_case/vww/src/VisualWakeWordProcessing.cc b/source/use_case/vww/src/VisualWakeWordProcessing.cc
index 94eae28..a9863c0 100644
--- a/source/use_case/vww/src/VisualWakeWordProcessing.cc
+++ b/source/use_case/vww/src/VisualWakeWordProcessing.cc
@@ -22,13 +22,9 @@
namespace arm {
namespace app {
- VisualWakeWordPreProcess::VisualWakeWordPreProcess(Model* model)
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ VisualWakeWordPreProcess::VisualWakeWordPreProcess(TfLiteTensor* inputTensor)
+ :m_inputTensor{inputTensor}
+ {}
bool VisualWakeWordPreProcess::DoPreProcess(const void* data, size_t inputSize)
{
@@ -37,9 +33,8 @@ namespace app {
}
auto input = static_cast<const uint8_t*>(data);
- TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0);
- auto unsignedDstPtr = static_cast<uint8_t*>(inputTensor->data.data);
+ auto unsignedDstPtr = static_cast<uint8_t*>(this->m_inputTensor->data.data);
/* VWW model has one channel input => Convert image to grayscale here.
* We expect images to always be RGB. */
@@ -47,10 +42,10 @@ namespace app {
/* VWW model pre-processing is image conversion from uint8 to [0,1] float values,
* then quantize them with input quantization info. */
- QuantParams inQuantParams = GetTensorQuantParams(inputTensor);
+ QuantParams inQuantParams = GetTensorQuantParams(this->m_inputTensor);
- auto signedDstPtr = static_cast<int8_t*>(inputTensor->data.data);
- for (size_t i = 0; i < inputTensor->bytes; i++) {
+ auto signedDstPtr = static_cast<int8_t*>(this->m_inputTensor->data.data);
+ for (size_t i = 0; i < this->m_inputTensor->bytes; i++) {
auto i_data_int8 = static_cast<int8_t>(
((static_cast<float>(unsignedDstPtr[i]) / 255.0f) / inQuantParams.scale) + inQuantParams.offset
);
@@ -62,22 +57,18 @@ namespace app {
return true;
}
- VisualWakeWordPostProcess::VisualWakeWordPostProcess(Classifier& classifier, Model* model,
+ VisualWakeWordPostProcess::VisualWakeWordPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels, std::vector<ClassificationResult>& results)
- :m_vwwClassifier{classifier},
+ :m_outputTensor{outputTensor},
+ m_vwwClassifier{classifier},
m_labels{labels},
m_results{results}
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ {}
bool VisualWakeWordPostProcess::DoPostProcess()
{
return this->m_vwwClassifier.GetClassificationResults(
- this->m_model->GetOutputTensor(0), this->m_results,
+ this->m_outputTensor, this->m_results,
this->m_labels, 1, true);
}