diff options
author | alexander <alexander.efremov@arm.com> | 2021-03-26 21:42:19 +0000 |
---|---|---|
committer | Kshitij Sisodia <kshitij.sisodia@arm.com> | 2021-03-29 16:29:55 +0100 |
commit | 3c79893217bc632c9b0efa815091bef3c779490c (patch) | |
tree | ad06b444557eb8124652b45621d736fa1b92f65d /source/use_case/img_class/src | |
parent | 6ad6d55715928de72979b04194da1bdf04a4c51b (diff) | |
download | ml-embedded-evaluation-kit-3c79893217bc632c9b0efa815091bef3c779490c.tar.gz |
Opensource ML embedded evaluation kit21.03
Change-Id: I12e807f19f5cacad7cef82572b6dd48252fd61fd
Diffstat (limited to 'source/use_case/img_class/src')
-rw-r--r-- | source/use_case/img_class/src/MainLoop.cc | 109 | ||||
-rw-r--r-- | source/use_case/img_class/src/MobileNetModel.cc | 57 | ||||
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 269 |
3 files changed, 435 insertions, 0 deletions
diff --git a/source/use_case/img_class/src/MainLoop.cc b/source/use_case/img_class/src/MainLoop.cc new file mode 100644 index 0000000..469907c --- /dev/null +++ b/source/use_case/img_class/src/MainLoop.cc @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* Brings in platform definitions. */ +#include "Classifier.hpp" /* Classifier. */ +#include "InputFiles.hpp" /* For input images. */ +#include "Labels.hpp" /* For label strings. */ +#include "MobileNetModel.hpp" /* Model class for running inference. */ +#include "UseCaseHandler.hpp" /* Handlers for different user options. */ +#include "UseCaseCommonUtils.hpp" /* Utils functions. */ + +using ImgClassClassifier = arm::app::Classifier; + +enum opcodes +{ + MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ + MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index. */ + MENU_OPT_RUN_INF_ALL, /* Run inference on all. */ + MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ + MENU_OPT_LIST_IMAGES /* List the current baked images. */ +}; + +static void DisplayMenu() +{ + printf("\n\nUser input required\n"); + printf("Enter option number from:\n\n"); + printf(" %u. Classify next image\n", MENU_OPT_RUN_INF_NEXT); + printf(" %u. Classify image at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); + printf(" %u. Run classification on all images\n", MENU_OPT_RUN_INF_ALL); + printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); + printf(" %u. List images\n\n", MENU_OPT_LIST_IMAGES); + printf(" Choice: "); +} + +void main_loop(hal_platform& platform) +{ + arm::app::MobileNetModel model; /* Model wrapper object. */ + + /* Load the model. */ + if (!model.Init()) { + printf_err("Failed to initialise model\n"); + return; + } + + /* Instantiate application context. */ + arm::app::ApplicationContext caseContext; + + caseContext.Set<hal_platform&>("platform", platform); + caseContext.Set<arm::app::Model&>("model", model); + caseContext.Set<uint32_t>("imgIndex", 0); + + ImgClassClassifier classifier; /* Classifier wrapper object. */ + caseContext.Set<arm::app::Classifier&>("classifier", classifier); + + std::vector <std::string> labels; + GetLabelsVector(labels); + caseContext.Set<const std::vector <std::string>&>("labels", labels); + + /* Loop. */ + bool executionSuccessful = true; + constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; + + /* Loop. */ + do { + int menuOption = MENU_OPT_RUN_INF_NEXT; + if (bUseMenu) { + DisplayMenu(); + menuOption = arm::app::ReadUserInputAsInt(platform); + printf("\n"); + } + switch (menuOption) { + case MENU_OPT_RUN_INF_NEXT: + executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), false); + break; + case MENU_OPT_RUN_INF_CHOSEN: { + printf(" Enter the image index [0, %d]: ", NUMBER_OF_FILES-1); + auto imgIndex = static_cast<uint32_t>(arm::app::ReadUserInputAsInt(platform)); + executionSuccessful = ClassifyImageHandler(caseContext, imgIndex, false); + break; + } + case MENU_OPT_RUN_INF_ALL: + executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), true); + break; + case MENU_OPT_SHOW_MODEL_INFO: + executionSuccessful = model.ShowModelInfoHandler(); + break; + case MENU_OPT_LIST_IMAGES: + executionSuccessful = ListFilesHandler(caseContext); + break; + default: + printf("Incorrect choice, try again."); + break; + } + } while (executionSuccessful && bUseMenu); + info("Main loop terminated.\n"); +}
\ No newline at end of file diff --git a/source/use_case/img_class/src/MobileNetModel.cc b/source/use_case/img_class/src/MobileNetModel.cc new file mode 100644 index 0000000..eeaa109 --- /dev/null +++ b/source/use_case/img_class/src/MobileNetModel.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MobileNetModel.hpp" + +#include "hal.h" + +const tflite::MicroOpResolver& arm::app::MobileNetModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +bool arm::app::MobileNetModel::EnlistOperations() +{ + this->_m_opResolver.AddDepthwiseConv2D(); + this->_m_opResolver.AddConv2D(); + this->_m_opResolver.AddAveragePool2D(); + this->_m_opResolver.AddAdd(); + this->_m_opResolver.AddReshape(); + this->_m_opResolver.AddSoftmax(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->_m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::MobileNetModel::ModelPointer() +{ + return GetModelPointer(); +} + +extern size_t GetModelLen(); +size_t arm::app::MobileNetModel::ModelSize() +{ + return GetModelLen(); +}
\ No newline at end of file diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc new file mode 100644 index 0000000..a412fec --- /dev/null +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "UseCaseHandler.hpp" + +#include "Classifier.hpp" +#include "InputFiles.hpp" +#include "MobileNetModel.hpp" +#include "UseCaseCommonUtils.hpp" +#include "hal.h" + +using ImgClassClassifier = arm::app::Classifier; + +namespace arm { +namespace app { + + /** + * @brief Helper function to load the current image into the input + * tensor. + * @param[in] imIdx Image index (from the pool of images available + * to the application). + * @param[out] inputTensor Pointer to the input tensor to be populated. + * @return true if tensor is loaded, false otherwise. + **/ + static bool _LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor); + + /** + * @brief Helper function to increment current image index. + * @param[in,out] ctx Pointer to the application context object. + **/ + static void _IncrementAppCtxImageIdx(ApplicationContext& ctx); + + /** + * @brief Helper function to set the image index. + * @param[in,out] ctx Pointer to the application context object. + * @param[in] idx Value to be set. + * @return true if index is set, false otherwise. + **/ + static bool _SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx); + + /** + * @brief Presents inference results using the data presentation + * object. + * @param[in] platform Reference to the hal platform object. + * @param[in] results Vector of classification results to be displayed. + * @param[in] infTimeMs Inference time in milliseconds, if available + * otherwise, this can be passed in as 0. + * @return true if successful, false otherwise. + **/ + static bool _PresentInferenceResult(hal_platform& platform, + const std::vector<ClassificationResult>& results); + + /** + * @brief Helper function to convert a UINT8 image to INT8 format. + * @param[in,out] data Pointer to the data start. + * @param[in] kMaxImageSize Total number of pixels in the image. + **/ + static void ConvertImgToInt8(void* data, size_t kMaxImageSize); + + /* Image inference classification handler. */ + bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) + { + auto& platform = ctx.Get<hal_platform&>("platform"); + + constexpr uint32_t dataPsnImgDownscaleFactor = 2; + constexpr uint32_t dataPsnImgStartX = 10; + constexpr uint32_t dataPsnImgStartY = 35; + + constexpr uint32_t dataPsnTxtInfStartX = 150; + constexpr uint32_t dataPsnTxtInfStartY = 40; + + platform.data_psn->clear(COLOR_BLACK); + + auto& model = ctx.Get<Model&>("model"); + + /* If the request has a valid size, set the image index. */ + if (imgIndex < NUMBER_OF_FILES) { + if (!_SetAppCtxImageIdx(ctx, imgIndex)) { + return false; + } + } + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + auto curImIdx = ctx.Get<uint32_t>("imgIndex"); + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); + + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; + } + + TfLiteIntArray* inputShape = model.GetInputShape(0); + + const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx]; + const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx]; + const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; + + std::vector<ClassificationResult> results; + + do { + /* Strings for presentation/logging. */ + std::string str_inf{"Running inference... "}; + + /* Copy over the data. */ + _LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor); + + /* Display this image on the LCD. */ + platform.data_psn->present_data_image( + (uint8_t*) inputTensor->data.data, + nCols, nRows, nChannels, + dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); + + /* If the data is signed. */ + if (model.IsDataSigned()) { + ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + } + + /* Display message on the LCD - inference running. */ + platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + /* Run inference over this image. */ + info("Running inference on image %u => %s\n", ctx.Get<uint32_t>("imgIndex"), + get_filename(ctx.Get<uint32_t>("imgIndex"))); + + RunInference(platform, model); + + /* Erase. */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + auto& classifier = ctx.Get<ImgClassClassifier&>("classifier"); + classifier.GetClassificationResults(outputTensor, results, + ctx.Get<std::vector <std::string>&>("labels"), + 5); + + /* Add results to context for access outside handler. */ + ctx.Set<std::vector<ClassificationResult>>("results", results); + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(outputTensor); +#endif /* VERIFY_TEST_OUTPUT */ + + if (!_PresentInferenceResult(platform, results)) { + return false; + } + + _IncrementAppCtxImageIdx(ctx); + + } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); + + return true; + } + + static bool _LoadImageIntoTensor(const uint32_t imIdx, TfLiteTensor* inputTensor) + { + const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? + inputTensor->bytes : IMAGE_DATA_SIZE; + const uint8_t* imgSrc = get_img_array(imIdx); + if (nullptr == imgSrc) { + printf_err("Failed to get image index %u (max: %u)\n", imIdx, + NUMBER_OF_FILES - 1); + return false; + } + + memcpy(inputTensor->data.data, imgSrc, copySz); + debug("Image %u loaded\n", imIdx); + return true; + } + + static void _IncrementAppCtxImageIdx(ApplicationContext& ctx) + { + auto curImIdx = ctx.Get<uint32_t>("imgIndex"); + + if (curImIdx + 1 >= NUMBER_OF_FILES) { + ctx.Set<uint32_t>("imgIndex", 0); + return; + } + ++curImIdx; + ctx.Set<uint32_t>("imgIndex", curImIdx); + } + + static bool _SetAppCtxImageIdx(ApplicationContext& ctx, const uint32_t idx) + { + if (idx >= NUMBER_OF_FILES) { + printf_err("Invalid idx %u (expected less than %u)\n", + idx, NUMBER_OF_FILES); + return false; + } + ctx.Set<uint32_t>("imgIndex", idx); + return true; + } + + static bool _PresentInferenceResult(hal_platform& platform, + const std::vector<ClassificationResult>& results) + { + constexpr uint32_t dataPsnTxtStartX1 = 150; + constexpr uint32_t dataPsnTxtStartY1 = 30; + + constexpr uint32_t dataPsnTxtStartX2 = 10; + constexpr uint32_t dataPsnTxtStartY2 = 150; + + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ + + platform.data_psn->set_text_color(COLOR_GREEN); + + /* Display each result. */ + uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; + uint32_t rowIdx2 = dataPsnTxtStartY2; + + for (uint32_t i = 0; i < results.size(); ++i) { + std::string resultStr = + std::to_string(i + 1) + ") " + + std::to_string(results[i].m_labelIdx) + + " (" + std::to_string(results[i].m_normalisedVal) + ")"; + + platform.data_psn->present_data_text( + resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX1, rowIdx1, 0); + rowIdx1 += dataPsnTxtYIncr; + + resultStr = std::to_string(i + 1) + ") " + results[i].m_label; + platform.data_psn->present_data_text( + resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX2, rowIdx2, 0); + rowIdx2 += dataPsnTxtYIncr; + + info("%u) %u (%f) -> %s\n", i, results[i].m_labelIdx, + results[i].m_normalisedVal, results[i].m_label.c_str()); + } + + return true; + } + + static void ConvertImgToInt8(void* data, const size_t kMaxImageSize) + { + auto* tmp_req_data = (uint8_t*) data; + auto* tmp_signed_req_data = (int8_t*) data; + + for (size_t i = 0; i < kMaxImageSize; i++) { + tmp_signed_req_data[i] = (int8_t) ( + (int32_t) (tmp_req_data[i]) - 128); + } + } + +} /* namespace app */ +} /* namespace arm */ |