diff options
Diffstat (limited to 'source/use_case')
-rw-r--r-- | source/use_case/img_class/include/ImgClassProcessing.hpp | 63 | ||||
-rw-r--r-- | source/use_case/img_class/src/ImgClassProcessing.cc | 66 | ||||
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 88 |
3 files changed, 166 insertions, 51 deletions
diff --git a/source/use_case/img_class/include/ImgClassProcessing.hpp b/source/use_case/img_class/include/ImgClassProcessing.hpp new file mode 100644 index 0000000..5a59b5f --- /dev/null +++ b/source/use_case/img_class/include/ImgClassProcessing.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IMG_CLASS_PROCESSING_HPP +#define IMG_CLASS_PROCESSING_HPP + +#include "BaseProcessing.hpp" +#include "Model.hpp" +#include "Classifier.hpp" + +namespace arm { +namespace app { + + /** + * @brief Pre-processing class for Image Classification use case. + * Implements methods declared by BasePreProcess and anything else needed + * to populate input tensors ready for inference. + */ + class ImgClassPreProcess : public BasePreProcess { + + public: + explicit ImgClassPreProcess(Model* model); + + bool DoPreProcess(const void* input, size_t inputSize) override; + }; + + /** + * @brief Post-processing class for Image Classification use case. + * Implements methods declared by BasePostProcess and anything else needed + * to populate result vector. + */ + class ImgClassPostProcess : public BasePostProcess { + + private: + Classifier& m_imgClassifier; + const std::vector<std::string>& m_labels; + std::vector<ClassificationResult>& m_results; + + public: + ImgClassPostProcess(Classifier& classifier, Model* model, + const std::vector<std::string>& labels, + std::vector<ClassificationResult>& results); + + bool DoPostProcess() override; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* IMG_CLASS_PROCESSING_HPP */
\ No newline at end of file diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc new file mode 100644 index 0000000..e33e3c1 --- /dev/null +++ b/source/use_case/img_class/src/ImgClassProcessing.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2022 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ImgClassProcessing.hpp" +#include "ImageUtils.hpp" +#include "log_macros.h" + +namespace arm { +namespace app { + + ImgClassPreProcess::ImgClassPreProcess(Model* model) + { + this->m_model = model; + } + + bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize) + { + if (data == nullptr) { + printf_err("Data pointer is null"); + } + + auto input = static_cast<const uint8_t*>(data); + TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0); + + memcpy(inputTensor->data.data, input, inputSize); + debug("Input tensor populated \n"); + + if (this->m_model->IsDataSigned()) { + image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + } + + return true; + } + + ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model, + const std::vector<std::string>& labels, + std::vector<ClassificationResult>& results) + :m_imgClassifier{classifier}, + m_labels{labels}, + m_results{results} + { + this->m_model = model; + } + + bool ImgClassPostProcess::DoPostProcess() + { + return this->m_imgClassifier.GetClassificationResults( + this->m_model->GetOutputTensor(0), this->m_results, + this->m_labels, 5, false); + } + +} /* namespace app */ +} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index 9061282..98e2b59 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -23,6 +23,7 @@ #include "UseCaseCommonUtils.hpp" #include "hal.h" #include "log_macros.h" +#include "ImgClassProcessing.hpp" #include <cinttypes> @@ -31,20 +32,12 @@ using ImgClassClassifier = arm::app::Classifier; namespace arm { namespace app { - /** - * @brief Helper function to load the current image into the input - * tensor. - * @param[in] imIdx Image index (from the pool of images available - * to the application). - * @param[out] inputTensor Pointer to the input tensor to be populated. - * @return true if tensor is loaded, false otherwise. - **/ - static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor); - - /* Image inference classification handler. */ + /* Image classification inference handler. */ bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { auto& profiler = ctx.Get<Profiler&>("profiler"); + auto& model = ctx.Get<Model&>("model"); + auto initialImIdx = ctx.Get<uint32_t>("imgIndex"); constexpr uint32_t dataPsnImgDownscaleFactor = 2; constexpr uint32_t dataPsnImgStartX = 10; @@ -53,8 +46,6 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 150; constexpr uint32_t dataPsnTxtInfStartY = 40; - auto& model = ctx.Get<Model&>("model"); - /* If the request has a valid size, set the image index. */ if (imgIndex < NUMBER_OF_FILES) { if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) { @@ -66,11 +57,7 @@ namespace app { return false; } - auto curImIdx = ctx.Get<uint32_t>("imgIndex"); - - TfLiteTensor* outputTensor = model.GetOutputTensor(0); TfLiteTensor* inputTensor = model.GetInputTensor(0); - if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -79,13 +66,20 @@ namespace app { return false; } + /* Get input shape for displaying the image. */ TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx]; const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx]; const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; + /* Set up pre and post-processing. */ + ImgClassPreProcess preprocess = ImgClassPreProcess(&model); + std::vector<ClassificationResult> results; + ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model, + ctx.Get<std::vector<std::string>&>("labels"), results); + + UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); do { hal_lcd_clear(COLOR_BLACK); @@ -93,29 +87,42 @@ namespace app { /* Strings for presentation/logging. */ std::string str_inf{"Running inference... "}; - /* Copy over the data. */ - LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor); + const uint8_t* imgSrc = get_img_array(ctx.Get<uint32_t>("imgIndex")); + if (nullptr == imgSrc) { + printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", ctx.Get<uint32_t>("imgIndex"), + NUMBER_OF_FILES - 1); + return false; + } /* Display this image on the LCD. */ hal_lcd_display_image( - static_cast<uint8_t *>(inputTensor->data.data), + imgSrc, nCols, nRows, nChannels, dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); - /* If the data is signed. */ - if (model.IsDataSigned()) { - image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); - } - /* Display message on the LCD - inference running. */ hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - /* Run inference over this image. */ + /* Select the image to run inference with. */ info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"), get_filename(ctx.Get<uint32_t>("imgIndex"))); - if (!RunInference(model, profiler)) { + const size_t imgSz = inputTensor->bytes < IMAGE_DATA_SIZE ? + inputTensor->bytes : IMAGE_DATA_SIZE; + + /* Run the pre-processing, inference and post-processing. */ + if (!runner.PreProcess(imgSrc, imgSz)) { + return false; + } + + profiler.StartProfiling("Inference"); + if (!runner.RunInference()) { + return false; + } + profiler.StopProfiling(); + + if (!runner.PostProcess()) { return false; } @@ -124,15 +131,11 @@ namespace app { hal_lcd_display_text(str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); - auto& classifier = ctx.Get<ImgClassClassifier&>("classifier"); - classifier.GetClassificationResults(outputTensor, results, - ctx.Get<std::vector <std::string>&>("labels"), - 5, false); - /* Add results to context for access outside handler. */ ctx.Set<std::vector<ClassificationResult>>("results", results); #if VERIFY_TEST_OUTPUT + TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ @@ -144,27 +147,10 @@ namespace app { IncrementAppCtxIfmIdx(ctx,"imgIndex"); - } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); - - return true; - } + } while (runAll && ctx.Get<uint32_t>("imgIndex") != initialImIdx); - static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor) - { - const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? - inputTensor->bytes : IMAGE_DATA_SIZE; - const uint8_t* imgSrc = get_img_array(imIdx); - if (nullptr == imgSrc) { - printf_err("Failed to get image index %" PRIu32 " (max: %u)\n", imIdx, - NUMBER_OF_FILES - 1); - return false; - } - - memcpy(inputTensor->data.data, imgSrc, copySz); - debug("Image %" PRIu32 " loaded\n", imIdx); return true; } - } /* namespace app */ } /* namespace arm */ |