diff options
5 files changed, 301 insertions, 51 deletions
diff --git a/source/application/main/include/BaseProcessing.hpp b/source/application/main/include/BaseProcessing.hpp new file mode 100644 index 0000000..c1c3255 --- /dev/null +++ b/source/application/main/include/BaseProcessing.hpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BASE_PROCESSING_HPP +#define BASE_PROCESSING_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { + + /** + * @brief Base class exposing pre-processing API. + * Use cases should provide their own PreProcessing class that inherits from this one. + * All steps required to take raw input data and populate tensors ready for inference + * should be handled. + */ + class BasePreProcess { + + public: + virtual ~BasePreProcess() = default; + + /** + * @brief Should perform pre-processing of 'raw' input data and load it into + * TFLite Micro input tensors ready for inference + * @param[in] input Pointer to the data that pre-processing will work on. + * @param[in] inputSize Size of the input data. + * @return true if successful, false otherwise. + **/ + virtual bool DoPreProcess(const void* input, size_t inputSize) = 0; + + protected: + Model* m_model = nullptr; + }; + + /** + * @brief Base class exposing post-processing API. + * Use cases should provide their own PostProcessing class that inherits from this one. + * All steps required to take inference output and populate results vectors should be handled. + */ + class BasePostProcess { + + public: + virtual ~BasePostProcess() = default; + + /** + * @brief Should perform post-processing of the result of inference then populate + * populate result data for any later use. + * @return true if successful, false otherwise. + **/ + virtual bool DoPostProcess() = 0; + + protected: + Model* m_model = nullptr; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* BASE_PROCESSING_HPP */
\ No newline at end of file diff --git a/source/application/main/include/UseCaseCommonUtils.hpp b/source/application/main/include/UseCaseCommonUtils.hpp index 9b6d550..f79f6ed 100644 --- a/source/application/main/include/UseCaseCommonUtils.hpp +++ b/source/application/main/include/UseCaseCommonUtils.hpp @@ -24,6 +24,7 @@ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ #include "Classifier.hpp" /* Classifier. */ #include "InputFiles.hpp" +#include "BaseProcessing.hpp" void DisplayCommonMenu(); @@ -107,6 +108,67 @@ namespace app { **/ bool ListFilesHandler(ApplicationContext& ctx); + /** + * @brief Use case runner class that will handle calling pre-processing, + * inference and post-processing. + * After constructing an instance of this class the user can call + * PreProcess(), RunInference() and PostProcess() to perform inference. + */ + class UseCaseRunner { + + private: + BasePreProcess* m_preProcess; + BasePostProcess* m_postProcess; + Model* m_model; + + public: + explicit UseCaseRunner(BasePreProcess* preprocess, BasePostProcess* postprocess, Model* model) + : m_preProcess{preprocess}, + m_postProcess{postprocess}, + m_model{model} + {}; + + /** + * @brief Runs pre-processing as defined by PreProcess object within the runner. + * Templated for the input data type. + * @param[in] inputData Pointer to the data that inference will be performed on. + * @param[in] inputSize Size of the input data that inference will be performed on. + * @return true if successful, false otherwise. + **/ + template<typename T> + bool PreProcess(T* inputData, size_t inputSize) { + if (!this->m_preProcess->DoPreProcess(inputData, inputSize)) { + printf_err("Pre-processing failed."); + return false; + } + return true; + } + + /** + * @brief Runs inference with the Model object within the runner. + * @return true if successful, false otherwise. + **/ + bool RunInference() { + if (!this->m_model->RunInference()) { + printf_err("Inference failed."); + return false; + } + return true; + } + + /** + * @brief Runs post-processing as defined by PostProcess object within the runner. + * @return true if successful, false otherwise. + **/ + bool PostProcess() { + if (!this->m_postProcess->DoPostProcess()) { + printf_err("Post-processing failed."); + return false; + } + return true; + } + }; + } /* namespace app */ } /* namespace arm */ diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp new file mode 100644 index 0000000..5a59b5f --- /dev/null +++ b/source/use_case/img_class/include/ImgClassProcessing.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IMG_CLASS_PROCESSING_HPP +#define IMG_CLASS_PROCESSING_HPP + +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "Classifier.hpp" + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for Image Classification use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class ImgClassPreProcess : public BasePreProcess { + + public: + explicit ImgClassPreProcess(Model* model); + + bool DoPreProcess(const void* input, size_t inputSize) override; + }; + + /** + * @brief Post-processing class for Image Classification use case. + * Implements methods declared by BasePostProcess and anything else needed + * to populate result vector. + */ + class ImgClassPostProcess : public BasePostProcess { + + private: + Classifier& m_imgClassifier; + const std::vector<std::string>& m_labels; + std::vector<ClassificationResult>& m_results; + + public: + ImgClassPostProcess(Classifier& classifier, Model* model, + const std::vector<std::string>& labels, + std::vector<ClassificationResult>& results); + + bool DoPostProcess() override; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* IMG_CLASS_PROCESSING_HPP */
\ No newline at end of file diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc new file mode 100644 index 0000000..e33e3c1 --- /dev/null +++ b/source/use_case/img_class/src/ImgClassProcessing.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ImgClassProcessing.hpp" +#include "ImageUtils.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + ImgClassPreProcess::ImgClassPreProcess(Model* model) + { + this->m_model = model; + } + + bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize) + { + if (data == nullptr) { + printf_err("Data pointer is null"); + } + + auto input = static_cast<const uint8_t*>(data); + TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0); + + memcpy(inputTensor->data.data, input, inputSize); + debug("Input tensor populated \n"); + + if (this->m_model->IsDataSigned()) { + image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + } + + return true; + } + + ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model, + const std::vector<std::string>& labels, + std::vector<ClassificationResult>& results) + :m_imgClassifier{classifier}, + m_labels{labels}, + m_results{results} + { + this->m_model = model; + } + + bool ImgClassPostProcess::DoPostProcess() + { + return this->m_imgClassifier.GetClassificationResults( + this->m_model->GetOutputTensor(0), this->m_results, + this->m_labels, 5, false); + } + +} /* namespace app */ +} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index 9061282..98e2b59 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -23,6 +23,7 @@ #include "UseCaseCommonUtils.hpp" #include "hal.h" #include "log_macros.h" +#include "ImgClassProcessing.hpp" #include <cinttypes> @@ -31,20 +32,12 @@ using ImgClassClassifier = arm::app::Classifier; namespace arm { namespace app { - /** - * @brief Helper function to load the current image into the input - * tensor. - * @param[in] imIdx Image index (from the pool of images available - * to the application). - * @param[out] inputTensor Pointer to the input tensor to be populated. - * @return true if tensor is loaded, false otherwise. - **/ - static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor); - - /* Image inference classification handler. */ + /* Image classification inference handler. */ bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& model = ctx.Get<Model&>("model"); + auto initialImIdx = ctx.Get<uint32_t>("imgIndex"); constexpr uint32_t dataPsnImgDownscaleFactor = 2; constexpr uint32_t dataPsnImgStartX = 10; @@ -53,8 +46,6 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 150; constexpr uint32_t dataPsnTxtInfStartY = 40; - auto& model = ctx.Get<Model&>("model"); - /* If the request has a valid size, set the image index. */ if (imgIndex < NUMBER_OF_FILES) { if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) { @@ -66,11 +57,7 @@ namespace app { return false; } - auto curImIdx = ctx.Get<uint32_t>("imgIndex"); - - TfLiteTensor* outputTensor = model.GetOutputTensor(0); TfLiteTensor* inputTensor = model.GetInputTensor(0); - if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -79,13 +66,20 @@ namespace app { return false; } + /* Get input shape for displaying the image. */ TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx]; const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx]; const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; + /* Set up pre and post-processing. */ + ImgClassPreProcess preprocess = ImgClassPreProcess(&model); + std::vector<ClassificationResult> results; + ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model, + ctx.Get<std::vector<std::string>&>("labels"), results); + + UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); do { hal_lcd_clear(COLOR_BLACK); @@ -93,29 +87,42 @@ namespace app { /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; - /* Copy over the data. */ - LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor); + const uint8_t* imgSrc = get_img_array(ctx.Get<uint32_t>("imgIndex")); + if (nullptr == imgSrc) { + printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", ctx.Get<uint32_t>("imgIndex"), + NUMBER_OF_FILES - 1); + return false; + } /* Display this image on the LCD. */ hal_lcd_display_image( - static_cast<uint8_t *>(inputTensor->data.data), + imgSrc, nCols, nRows, nChannels, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); - /* If the data is signed. */ - if (model.IsDataSigned()) { - image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); - } - /* Display message on the LCD - inference running. */ hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - /* Run inference over this image. */ + /* Select the image to run inference with. */ info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"), get_filename(ctx.Get<uint32_t>("imgIndex"))); - if (!RunInference(model, profiler)) { + const size_t imgSz = inputTensor->bytes < IMAGE_DATA_SIZE ? + inputTensor->bytes : IMAGE_DATA_SIZE; + + /* Run the pre-processing, inference and post-processing. */ + if (!runner.PreProcess(imgSrc, imgSz)) { + return false; + } + + profiler.StartProfiling("Inference"); + if (!runner.RunInference()) { + return false; + } + profiler.StopProfiling(); + + if (!runner.PostProcess()) { return false; } @@ -124,15 +131,11 @@ namespace app { hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - auto& classifier = ctx.Get<ImgClassClassifier&>("classifier"); - classifier.GetClassificationResults(outputTensor, results, - ctx.Get<std::vector <std::string>&>("labels"), - 5, false); - /* Add results to context for access outside handler. */ ctx.Set<std::vector<ClassificationResult>>("results", results); #if VERIFY_TEST_OUTPUT + TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ @@ -144,27 +147,10 @@ namespace app { IncrementAppCtxIfmIdx(ctx,"imgIndex"); - } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); - - return true; - } + } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImIdx); - static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor) - { - const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? - inputTensor->bytes : IMAGE_DATA_SIZE; - const uint8_t* imgSrc = get_img_array(imIdx); - if (nullptr == imgSrc) { - printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", imIdx, - NUMBER_OF_FILES - 1); - return false; - } - - memcpy(inputTensor->data.data, imgSrc, copySz); - debug("Image %" PRIu32 " loaded\n", imIdx); return true; } - } /* namespace app */ } /* namespace arm */ |