diff options
Diffstat (limited to 'source/use_case/vww')
-rw-r--r-- | source/use_case/vww/include/VisualWakeWordProcessing.hpp | 25 | ||||
-rw-r--r-- | source/use_case/vww/src/UseCaseHandler.cc | 22 | ||||
-rw-r--r-- | source/use_case/vww/src/VisualWakeWordProcessing.cc | 33 |
3 files changed, 37 insertions, 43 deletions
diff --git a/source/use_case/vww/include/VisualWakeWordProcessing.hpp b/source/use_case/vww/include/VisualWakeWordProcessing.hpp index b1d68ce..bef161f 100644 --- a/source/use_case/vww/include/VisualWakeWordProcessing.hpp +++ b/source/use_case/vww/include/VisualWakeWordProcessing.hpp @@ -34,9 +34,9 @@ namespace app { public: /** * @brief Constructor - * @param[in] model Pointer to the the Image classification Model object. + * @param[in] inputTensor Pointer to the TFLite Micro input Tensor. **/ - explicit VisualWakeWordPreProcess(Model* model); + explicit VisualWakeWordPreProcess(TfLiteTensor* inputTensor); /** * @brief Should perform pre-processing of 'raw' input image data and load it into @@ -46,6 +46,9 @@ namespace app { * @return true if successful, false otherwise. **/ bool DoPreProcess(const void* input, size_t inputSize) override; + + private: + TfLiteTensor* m_inputTensor; }; /** @@ -56,6 +59,7 @@ namespace app { class VisualWakeWordPostProcess : public BasePostProcess { private: + TfLiteTensor* m_outputTensor; Classifier& m_vwwClassifier; const std::vector<std::string>& m_labels; std::vector<ClassificationResult>& m_results; @@ -63,19 +67,20 @@ namespace app { public: /** * @brief Constructor - * @param[in] classifier Classifier object used to get top N results from classification. - * @param[in] model Pointer to the VWW classification Model object. - * @param[in] labels Vector of string labels to identify each output of the model. - * @param[out] results Vector of classification results to store decoded outputs. + * @param[in] outputTensor Pointer to the TFLite Micro output Tensor. + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] model Pointer to the VWW classification Model object. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[out] results Vector of classification results to store decoded outputs. **/ - VisualWakeWordPostProcess(Classifier& classifier, Model* model, + VisualWakeWordPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results); /** - * @brief Should perform post-processing of the result of inference then - * populate classification result data for any later use. - * @return true if successful, false otherwise. + * @brief Should perform post-processing of the result of inference then + * populate classification result data for any later use. + * @return true if successful, false otherwise. **/ bool DoPostProcess() override; }; diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc index 7681f89..267e6c4 100644 --- a/source/use_case/vww/src/UseCaseHandler.cc +++ b/source/use_case/vww/src/UseCaseHandler.cc @@ -53,7 +53,7 @@ namespace app { } TfLiteTensor* inputTensor = model.GetInputTensor(0); - + TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -75,15 +75,13 @@ namespace app { const uint32_t displayChannels = 3; /* Set up pre and post-processing. */ - VisualWakeWordPreProcess preprocess = VisualWakeWordPreProcess(&model); + VisualWakeWordPreProcess preProcess = VisualWakeWordPreProcess(inputTensor); std::vector<ClassificationResult> results; - VisualWakeWordPostProcess postprocess = VisualWakeWordPostProcess( - ctx.Get<Classifier&>("classifier"), &model, + VisualWakeWordPostProcess postProcess = VisualWakeWordPostProcess(outputTensor, + ctx.Get<Classifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), results); - UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); - do { hal_lcd_clear(COLOR_BLACK); @@ -115,17 +113,18 @@ namespace app { inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ - if (!runner.PreProcess(imgSrc, imgSz)) { + if (!preProcess.DoPreProcess(imgSrc, imgSz)) { + printf_err("Pre-processing failed."); return false; } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } @@ -138,7 +137,6 @@ namespace app { ctx.Set<std::vector<ClassificationResult>>("results", results); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ diff --git a/source/use_case/vww/src/VisualWakeWordProcessing.cc b/source/use_case/vww/src/VisualWakeWordProcessing.cc index 94eae28..a9863c0 100644 --- a/source/use_case/vww/src/VisualWakeWordProcessing.cc +++ b/source/use_case/vww/src/VisualWakeWordProcessing.cc @@ -22,13 +22,9 @@ namespace arm { namespace app { - VisualWakeWordPreProcess::VisualWakeWordPreProcess(Model* model) - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + VisualWakeWordPreProcess::VisualWakeWordPreProcess(TfLiteTensor* inputTensor) + :m_inputTensor{inputTensor} + {} bool VisualWakeWordPreProcess::DoPreProcess(const void* data, size_t inputSize) { @@ -37,9 +33,8 @@ namespace app { } auto input = static_cast<const uint8_t*>(data); - TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0); - auto unsignedDstPtr = static_cast<uint8_t*>(inputTensor->data.data); + auto unsignedDstPtr = static_cast<uint8_t*>(this->m_inputTensor->data.data); /* VWW model has one channel input => Convert image to grayscale here. * We expect images to always be RGB. */ @@ -47,10 +42,10 @@ namespace app { /* VWW model pre-processing is image conversion from uint8 to [0,1] float values, * then quantize them with input quantization info. */ - QuantParams inQuantParams = GetTensorQuantParams(inputTensor); + QuantParams inQuantParams = GetTensorQuantParams(this->m_inputTensor); - auto signedDstPtr = static_cast<int8_t*>(inputTensor->data.data); - for (size_t i = 0; i < inputTensor->bytes; i++) { + auto signedDstPtr = static_cast<int8_t*>(this->m_inputTensor->data.data); + for (size_t i = 0; i < this->m_inputTensor->bytes; i++) { auto i_data_int8 = static_cast<int8_t>( ((static_cast<float>(unsignedDstPtr[i]) / 255.0f) / inQuantParams.scale) + inQuantParams.offset ); @@ -62,22 +57,18 @@ namespace app { return true; } - VisualWakeWordPostProcess::VisualWakeWordPostProcess(Classifier& classifier, Model* model, + VisualWakeWordPostProcess::VisualWakeWordPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results) - :m_vwwClassifier{classifier}, + :m_outputTensor{outputTensor}, + m_vwwClassifier{classifier}, m_labels{labels}, m_results{results} - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + {} bool VisualWakeWordPostProcess::DoPostProcess() { return this->m_vwwClassifier.GetClassificationResults( - this->m_model->GetOutputTensor(0), this->m_results, + this->m_outputTensor, this->m_results, this->m_labels, 1, true); } |