From c20be97af1c3e4d569d37587cd22d343193c7563 Mon Sep 17 00:00:00 2001 From: Richard Burton Date: Tue, 19 Apr 2022 17:01:08 +0100 Subject: MLECO-3078: Add VWW use case API We now expect that img_class and vww use cases model inputs should have 4 dims as dictated by use case logic Signed-off-by: Richard Burton Change-Id: I67a57a3a28a7ff2e09c917c40e9fc2c08384a45c --- source/use_case/img_class/src/UseCaseHandler.cc | 8 +- .../vww/include/VisualWakeWordProcessing.hpp | 86 ++++++++++++ source/use_case/vww/src/UseCaseHandler.cc | 144 ++++++++------------- .../use_case/vww/src/VisualWakeWordProcessing.cc | 85 ++++++++++++ 4 files changed, 230 insertions(+), 93 deletions(-) create mode 100644 source/use_case/vww/include/VisualWakeWordProcessing.hpp create mode 100644 source/use_case/vww/src/VisualWakeWordProcessing.cc diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index 11a1aa8..c68d816 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -43,7 +43,7 @@ namespace app { return false; } } - auto initialImIdx = ctx.Get("imgIndex"); + auto initialImgIdx = ctx.Get("imgIndex"); constexpr uint32_t dataPsnImgDownscaleFactor = 2; constexpr uint32_t dataPsnImgStartX = 10; @@ -62,8 +62,8 @@ namespace app { if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; - } else if (inputTensor->dims->size < 3) { - printf_err("Input tensor dimension should be >= 3\n"); + } else if (inputTensor->dims->size < 4) { + printf_err("Input tensor dimension should be = 4\n"); return false; } @@ -148,7 +148,7 @@ namespace app { IncrementAppCtxIfmIdx(ctx,"imgIndex"); - } while (runAll && ctx.Get("imgIndex") != initialImIdx); + } while (runAll && ctx.Get("imgIndex") != initialImgIdx); return true; } diff --git a/source/use_case/vww/include/VisualWakeWordProcessing.hpp b/source/use_case/vww/include/VisualWakeWordProcessing.hpp new file mode 100644 index 0000000..b1d68ce --- /dev/null +++ b/source/use_case/vww/include/VisualWakeWordProcessing.hpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef VWW_PROCESSING_HPP +#define VWW_PROCESSING_HPP + +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "Classifier.hpp" + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for Visual Wake Word use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class VisualWakeWordPreProcess : public BasePreProcess { + + public: + /** + * @brief Constructor + * @param[in] model Pointer to the the Image classification Model object. + **/ + explicit VisualWakeWordPreProcess(Model* model); + + /** + * @brief Should perform pre-processing of 'raw' input image data and load it into + * TFLite Micro input tensors ready for inference + * @param[in] input Pointer to the data that pre-processing will work on. + * @param[in] inputSize Size of the input data. + * @return true if successful, false otherwise. + **/ + bool DoPreProcess(const void* input, size_t inputSize) override; + }; + + /** + * @brief Post-processing class for Visual Wake Word use case. + * Implements methods declared by BasePostProcess and anything else needed + * to populate result vector. + */ + class VisualWakeWordPostProcess : public BasePostProcess { + + private: + Classifier& m_vwwClassifier; + const std::vector& m_labels; + std::vector& m_results; + + public: + /** + * @brief Constructor + * @param[in] classifier Classifier object used to get top N results from classification. + * @param[in] model Pointer to the VWW classification Model object. + * @param[in] labels Vector of string labels to identify each output of the model. + * @param[out] results Vector of classification results to store decoded outputs. + **/ + VisualWakeWordPostProcess(Classifier& classifier, Model* model, + const std::vector& labels, + std::vector& results); + + /** + * @brief Should perform post-processing of the result of inference then + * populate classification result data for any later use. + * @return true if successful, false otherwise. + **/ + bool DoPostProcess() override; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* VWW_PROCESSING_HPP */ \ No newline at end of file diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc index 56ba2b5..7681f89 100644 --- a/source/use_case/vww/src/UseCaseHandler.cc +++ b/source/use_case/vww/src/UseCaseHandler.cc @@ -22,27 +22,23 @@ #include "UseCaseCommonUtils.hpp" #include "hal.h" #include "log_macros.h" - -#include +#include "VisualWakeWordProcessing.hpp" namespace arm { namespace app { - /** - * @brief Helper function to load the current image into the input - * tensor. - * @param[in] imIdx Image index (from the pool of images available - * to the application). - * @param[out] inputTensor Pointer to the input tensor to be populated. - * @return true if tensor is loaded, false otherwise. - **/ - static bool LoadImageIntoTensor(uint32_t imIdx, - TfLiteTensor *inputTensor); - - /* Image inference classification handler. */ + /* Visual Wake Word inference handler. */ bool ClassifyImageHandler(ApplicationContext &ctx, uint32_t imgIndex, bool runAll) { auto& profiler = ctx.Get("profiler"); + auto& model = ctx.Get("model"); + /* If the request has a valid size, set the image index. */ + if (imgIndex < NUMBER_OF_FILES) { + if (!SetAppCtxIfmIdx(ctx, imgIndex,"imgIndex")) { + return false; + } + } + auto initialImgIdx = ctx.Get("imgIndex"); constexpr uint32_t dataPsnImgDownscaleFactor = 1; constexpr uint32_t dataPsnImgStartX = 10; @@ -51,31 +47,22 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 150; constexpr uint32_t dataPsnTxtInfStartY = 70; - auto& model = ctx.Get("model"); - - /* If the request has a valid size, set the image index. */ - if (imgIndex < NUMBER_OF_FILES) { - if (!SetAppCtxIfmIdx(ctx, imgIndex,"imgIndex")) { - return false; - } - } if (!model.IsInited()) { printf_err("Model is not initialised! Terminating processing.\n"); return false; } - auto curImIdx = ctx.Get("imgIndex"); - - TfLiteTensor *outputTensor = model.GetOutputTensor(0); - TfLiteTensor *inputTensor = model.GetInputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; - } else if (inputTensor->dims->size < 3) { - printf_err("Input tensor dimension should be >= 3\n"); + } else if (inputTensor->dims->size < 4) { + printf_err("Input tensor dimension should be = 4\n"); return false; } + + /* Get input shape for displaying the image. */ TfLiteIntArray* inputShape = model.GetInputShape(0); const uint32_t nCols = inputShape->data[arm::app::VisualWakeWordModel::ms_inputColsIdx]; const uint32_t nRows = inputShape->data[arm::app::VisualWakeWordModel::ms_inputRowsIdx]; @@ -83,9 +70,19 @@ namespace app { printf_err("Invalid channel index.\n"); return false; } - const uint32_t nChannels = inputShape->data[arm::app::VisualWakeWordModel::ms_inputChannelsIdx]; + + /* We expect RGB images to be provided. */ + const uint32_t displayChannels = 3; + + /* Set up pre and post-processing. */ + VisualWakeWordPreProcess preprocess = VisualWakeWordPreProcess(&model); std::vector results; + VisualWakeWordPostProcess postprocess = VisualWakeWordPostProcess( + ctx.Get("classifier"), &model, + ctx.Get&>("labels"), results); + + UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); do { hal_lcd_clear(COLOR_BLACK); @@ -93,54 +90,55 @@ namespace app { /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; - /* Copy over the data. */ - LoadImageIntoTensor(ctx.Get("imgIndex"), inputTensor); + const uint8_t* imgSrc = get_img_array(ctx.Get("imgIndex")); + if (nullptr == imgSrc) { + printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", ctx.Get("imgIndex"), + NUMBER_OF_FILES - 1); + return false; + } /* Display this image on the LCD. */ hal_lcd_display_image( - static_cast(inputTensor->data.data), - nCols, nRows, nChannels, + imgSrc, + nCols, nRows, displayChannels, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); - /* Vww model preprocessing is image conversion from uint8 to [0,1] float values, - * then quantize them with input quantization info. */ - QuantParams inQuantParams = GetTensorQuantParams(inputTensor); - - auto* req_data = static_cast(inputTensor->data.data); - auto* signed_req_data = static_cast(inputTensor->data.data); - for (size_t i = 0; i < inputTensor->bytes; i++) { - auto i_data_int8 = static_cast(((static_cast(req_data[i]) / 255.0f) / inQuantParams.scale) + inQuantParams.offset); - signed_req_data[i] = std::min(INT8_MAX, std::max(i_data_int8, INT8_MIN)); - } - /* Display message on the LCD - inference running. */ - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); /* Run inference over this image. */ info("Running inference on image %" PRIu32 " => %s\n", ctx.Get("imgIndex"), get_filename(ctx.Get("imgIndex"))); - if (!RunInference(model, profiler)) { + const size_t imgSz = inputTensor->bytes < IMAGE_DATA_SIZE ? + inputTensor->bytes : IMAGE_DATA_SIZE; + + /* Run the pre-processing, inference and post-processing. */ + if (!runner.PreProcess(imgSrc, imgSz)) { + return false; + } + + profiler.StartProfiling("Inference"); + if (!runner.RunInference()) { + return false; + } + profiler.StopProfiling(); + + if (!runner.PostProcess()) { return false; } /* Erase. */ str_inf = std::string(str_inf.size(), ' '); - hal_lcd_display_text( - str_inf.c_str(), str_inf.size(), - dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); - - auto& classifier = ctx.Get("classifier"); - classifier.GetClassificationResults(outputTensor, results, - ctx.Get&>("labels"), 1, - false); + hal_lcd_display_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); /* Add results to context for access outside handler. */ ctx.Set>("results", results); #if VERIFY_TEST_OUTPUT + TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ @@ -149,43 +147,11 @@ namespace app { } profiler.PrintProfilingResult(); - IncrementAppCtxIfmIdx(ctx,"imgIndex"); - } while (runAll && ctx.Get("imgIndex") != curImIdx); - - return true; - } - - static bool LoadImageIntoTensor(const uint32_t imIdx, - TfLiteTensor *inputTensor) - { - const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? - inputTensor->bytes : IMAGE_DATA_SIZE; - if (imIdx >= NUMBER_OF_FILES) { - printf_err("invalid image index %" PRIu32 " (max: %u)\n", imIdx, - NUMBER_OF_FILES - 1); - return false; - } + IncrementAppCtxIfmIdx(ctx,"imgIndex"); - if (arm::app::VisualWakeWordModel::ms_inputChannelsIdx >= static_cast(inputTensor->dims->size)) { - printf_err("Invalid channel index.\n"); - return false; - } - const uint32_t nChannels = inputTensor->dims->data[arm::app::VisualWakeWordModel::ms_inputChannelsIdx]; - - const uint8_t* srcPtr = get_img_array(imIdx); - auto* dstPtr = static_cast(inputTensor->data.data); - if (1 == nChannels) { - /** - * Visual Wake Word model accepts only one channel => - * Convert image to grayscale here - **/ - image::RgbToGrayscale(srcPtr, dstPtr, copySz); - } else { - memcpy(inputTensor->data.data, srcPtr, copySz); - } + } while (runAll && ctx.Get("imgIndex") != initialImgIdx); - debug("Image %" PRIu32 " loaded\n", imIdx); return true; } diff --git a/source/use_case/vww/src/VisualWakeWordProcessing.cc b/source/use_case/vww/src/VisualWakeWordProcessing.cc new file mode 100644 index 0000000..94eae28 --- /dev/null +++ b/source/use_case/vww/src/VisualWakeWordProcessing.cc @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "VisualWakeWordProcessing.hpp" +#include "ImageUtils.hpp" +#include "VisualWakeWordModel.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + VisualWakeWordPreProcess::VisualWakeWordPreProcess(Model* model) + { + if (!model->IsInited()) { + printf_err("Model is not initialised!.\n"); + } + this->m_model = model; + } + + bool VisualWakeWordPreProcess::DoPreProcess(const void* data, size_t inputSize) + { + if (data == nullptr) { + printf_err("Data pointer is null"); + } + + auto input = static_cast(data); + TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0); + + auto unsignedDstPtr = static_cast(inputTensor->data.data); + + /* VWW model has one channel input => Convert image to grayscale here. + * We expect images to always be RGB. */ + image::RgbToGrayscale(input, unsignedDstPtr, inputSize); + + /* VWW model pre-processing is image conversion from uint8 to [0,1] float values, + * then quantize them with input quantization info. */ + QuantParams inQuantParams = GetTensorQuantParams(inputTensor); + + auto signedDstPtr = static_cast(inputTensor->data.data); + for (size_t i = 0; i < inputTensor->bytes; i++) { + auto i_data_int8 = static_cast( + ((static_cast(unsignedDstPtr[i]) / 255.0f) / inQuantParams.scale) + inQuantParams.offset + ); + signedDstPtr[i] = std::min(INT8_MAX, std::max(i_data_int8, INT8_MIN)); + } + + debug("Input tensor populated \n"); + + return true; + } + + VisualWakeWordPostProcess::VisualWakeWordPostProcess(Classifier& classifier, Model* model, + const std::vector& labels, std::vector& results) + :m_vwwClassifier{classifier}, + m_labels{labels}, + m_results{results} + { + if (!model->IsInited()) { + printf_err("Model is not initialised!.\n"); + } + this->m_model = model; + } + + bool VisualWakeWordPostProcess::DoPostProcess() + { + return this->m_vwwClassifier.GetClassificationResults( + this->m_model->GetOutputTensor(0), this->m_results, + this->m_labels, 1, true); + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file -- cgit v1.2.1