diff options
author | Éanna Ó Catháin <eanna.ocathain@arm.com> | 2021-09-15 09:32:30 +0100 |
---|---|---|
committer | Kshitij Sisodia <kshitij.sisodia@arm.com> | 2021-09-16 16:01:23 +0100 |
commit | 8f9588721cbb7356b03a714c97d6b3a9a6e89438 (patch) | |
tree | 1ca19d31958081c09f360d91e15fefb6e38b3992 /source/use_case/img_class/src | |
parent | e6588f620c648dd0492f6133152855d77c672568 (diff) | |
download | ml-embedded-evaluation-kit-8f9588721cbb7356b03a714c97d6b3a9a6e89438.tar.gz |
MLECO-2082: Adding visual wake word use case21.08
MLECO-2083: Refactoring img_class and visual wake word
*Added source files for visual wake word
*Added tests
*Added docs
*Added new images for visual wake word demo
*Refactored common functions in img_class, visual wake word and other usecases
Change-Id: Ibd25854e19a5517f940a8d3086a5d4835fab89e9
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'source/use_case/img_class/src')
-rw-r--r-- | source/use_case/img_class/src/MainLoop.cc | 37 | ||||
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 116 |
2 files changed, 11 insertions, 142 deletions
diff --git a/source/use_case/img_class/src/MainLoop.cc b/source/use_case/img_class/src/MainLoop.cc index 61a09dd..79f6018 100644 --- a/source/use_case/img_class/src/MainLoop.cc +++ b/source/use_case/img_class/src/MainLoop.cc @@ -24,29 +24,6 @@ using ImgClassClassifier = arm::app::Classifier; -enum opcodes -{ - MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ - MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index. */ - MENU_OPT_RUN_INF_ALL, /* Run inference on all. */ - MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ - MENU_OPT_LIST_IMAGES /* List the current baked images. */ -}; - -static void DisplayMenu() -{ - printf("\n\n"); - printf("User input required\n"); - printf("Enter option number from:\n\n"); - printf(" %u. Classify next image\n", MENU_OPT_RUN_INF_NEXT); - printf(" %u. Classify image at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); - printf(" %u. Run classification on all images\n", MENU_OPT_RUN_INF_ALL); - printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); - printf(" %u. List images\n\n", MENU_OPT_LIST_IMAGES); - printf(" Choice: "); - fflush(stdout); -} - void main_loop(hal_platform& platform) { arm::app::MobileNetModel model; /* Model wrapper object. */ @@ -79,29 +56,29 @@ void main_loop(hal_platform& platform) /* Loop. */ do { - int menuOption = MENU_OPT_RUN_INF_NEXT; + int menuOption = common::MENU_OPT_RUN_INF_NEXT; if (bUseMenu) { - DisplayMenu(); + DisplayCommonMenu(); menuOption = arm::app::ReadUserInputAsInt(platform); printf("\n"); } switch (menuOption) { - case MENU_OPT_RUN_INF_NEXT: + case common::MENU_OPT_RUN_INF_NEXT: executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), false); break; - case MENU_OPT_RUN_INF_CHOSEN: { + case common::MENU_OPT_RUN_INF_CHOSEN: { printf(" Enter the image index [0, %d]: ", NUMBER_OF_FILES-1); auto imgIndex = static_cast<uint32_t>(arm::app::ReadUserInputAsInt(platform)); executionSuccessful = ClassifyImageHandler(caseContext, imgIndex, false); break; } - case MENU_OPT_RUN_INF_ALL: + case common::MENU_OPT_RUN_INF_ALL: executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), true); break; - case MENU_OPT_SHOW_MODEL_INFO: + case common::MENU_OPT_SHOW_MODEL_INFO: executionSuccessful = model.ShowModelInfoHandler(); break; - case MENU_OPT_LIST_IMAGES: + case common::MENU_OPT_LIST_IFM: executionSuccessful = ListFilesHandler(caseContext); break; default: diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index 337cb29..66df1da 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -39,37 +39,6 @@ namespace app { **/ static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor); - /** - * @brief Helper function to increment current image index. - * @param[in,out] ctx Pointer to the application context object. - **/ - static void IncrementAppCtxImageIdx(ApplicationContext& ctx); - - /** - * @brief Helper function to set the image index. - * @param[in,out] ctx Pointer to the application context object. - * @param[in] idx Value to be set. - * @return true if index is set, false otherwise. - **/ - static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx); - - /** - * @brief Presents inference results using the data presentation - * object. - * @param[in] platform Reference to the hal platform object. - * @param[in] results Vector of classification results to be displayed. - * @return true if successful, false otherwise. - **/ - static bool PresentInferenceResult(hal_platform& platform, - const std::vector<ClassificationResult>& results); - - /** - * @brief Helper function to convert a UINT8 image to INT8 format. - * @param[in,out] data Pointer to the data start. - * @param[in] kMaxImageSize Total number of pixels in the image. - **/ - static void ConvertImgToInt8(void* data, size_t kMaxImageSize); - /* Image inference classification handler. */ bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { @@ -89,7 +58,7 @@ namespace app { /* If the request has a valid size, set the image index. */ if (imgIndex < NUMBER_OF_FILES) { - if (!SetAppCtxImageIdx(ctx, imgIndex)) { + if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) { return false; } } @@ -134,7 +103,7 @@ namespace app { /* If the data is signed. */ if (model.IsDataSigned()) { - ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); } /* Display message on the LCD - inference running. */ @@ -166,13 +135,13 @@ namespace app { arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ - if (!PresentInferenceResult(platform, results)) { + if (!image::PresentInferenceResult(platform, results)) { return false; } profiler.PrintProfilingResult(); - IncrementAppCtxImageIdx(ctx); + IncrementAppCtxIfmIdx(ctx,"imgIndex"); } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); @@ -195,83 +164,6 @@ namespace app { return true; } - static void IncrementAppCtxImageIdx(ApplicationContext& ctx) - { - auto curImIdx = ctx.Get<uint32_t>("imgIndex"); - - if (curImIdx + 1 >= NUMBER_OF_FILES) { - ctx.Set<uint32_t>("imgIndex", 0); - return; - } - ++curImIdx; - ctx.Set<uint32_t>("imgIndex", curImIdx); - } - - static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx) - { - if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); - return false; - } - ctx.Set<uint32_t>("imgIndex", idx); - return true; - } - - static bool PresentInferenceResult(hal_platform& platform, - const std::vector<ClassificationResult>& results) - { - constexpr uint32_t dataPsnTxtStartX1 = 150; - constexpr uint32_t dataPsnTxtStartY1 = 30; - - constexpr uint32_t dataPsnTxtStartX2 = 10; - constexpr uint32_t dataPsnTxtStartY2 = 150; - - constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ - - platform.data_psn->set_text_color(COLOR_GREEN); - - /* Display each result. */ - uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - uint32_t rowIdx2 = dataPsnTxtStartY2; - - info("Final results:\n"); - info("Total number of inferences: 1\n"); - for (uint32_t i = 0; i < results.size(); ++i) { - std::string resultStr = - std::to_string(i + 1) + ") " + - std::to_string(results[i].m_labelIdx) + - " (" + std::to_string(results[i].m_normalisedVal) + ")"; - - platform.data_psn->present_data_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, 0); - rowIdx1 += dataPsnTxtYIncr; - - resultStr = std::to_string(i + 1) + ") " + results[i].m_label; - platform.data_psn->present_data_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX2, rowIdx2, 0); - rowIdx2 += dataPsnTxtYIncr; - - info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", i, - results[i].m_labelIdx, results[i].m_normalisedVal, - results[i].m_label.c_str()); - } - - return true; - } - - static void ConvertImgToInt8(void* data, const size_t kMaxImageSize) - { - auto* tmp_req_data = (uint8_t*) data; - auto* tmp_signed_req_data = (int8_t*) data; - - for (size_t i = 0; i < kMaxImageSize; i++) { - tmp_signed_req_data[i] = (int8_t) ( - (int32_t) (tmp_req_data[i]) - 128); - } - } } /* namespace app */ } /* namespace arm */ |