From 8f9588721cbb7356b03a714c97d6b3a9a6e89438 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=89anna=20=C3=93=20Cath=C3=A1in?=
Date: Wed, 15 Sep 2021 09:32:30 +0100
Subject: MLECO-2082: Adding visual wake word use case
 MLECO-2083: Refactoring img_class and visual wake word
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Added source files for visual wake word
* Added tests
* Added docs
* Added new images for visual wake word demo
* Refactored common functions in img_class, visual wake word and other use cases

Change-Id: Ibd25854e19a5517f940a8d3086a5d4835fab89e9
Signed-off-by: Éanna Ó Catháin
---
 source/application/main/UseCaseCommonUtils.cc | 265 ++++++++++++++++++++------
 1 file changed, 202 insertions(+), 63 deletions(-)

(limited to 'source/application/main/UseCaseCommonUtils.cc')

diff --git a/source/application/main/UseCaseCommonUtils.cc b/source/application/main/UseCaseCommonUtils.cc
index 615f684..9834475 100644
--- a/source/application/main/UseCaseCommonUtils.cc
+++ b/source/application/main/UseCaseCommonUtils.cc
@@ -15,91 +15,230 @@
  * limitations under the License.
  */
 #include "UseCaseCommonUtils.hpp"
-
 #include "InputFiles.hpp"
-
 #include <inttypes.h>
 
-namespace arm {
-namespace app {
-    bool RunInference(arm::app::Model& model, Profiler& profiler)
-    {
-        profiler.StartProfiling("Inference");
-        bool runInf = model.RunInference();
-        profiler.StopProfiling();
+void DisplayCommonMenu()
+{
+    printf("\n\n");
+    printf("User input required\n");
+    printf("Enter option number from:\n\n");
+    printf("  %u. Classify next ifm\n", common::MENU_OPT_RUN_INF_NEXT);
+    printf("  %u. Classify ifm at chosen index\n", common::MENU_OPT_RUN_INF_CHOSEN);
+    printf("  %u. Run classification on all ifm\n", common::MENU_OPT_RUN_INF_ALL);
+    printf("  %u. Show NN model info\n", common::MENU_OPT_SHOW_MODEL_INFO);
+    printf("  %u. List ifm\n\n", common::MENU_OPT_LIST_IFM);
+    printf("  Choice: ");
+    fflush(stdout);
+}
+
+void image::ConvertImgToInt8(void* data, const size_t kMaxImageSize)
+{
+    auto* tmp_req_data = (uint8_t*) data;
+    auto* tmp_signed_req_data = (int8_t*) data;
 
-        return runInf;
+    for (size_t i = 0; i < kMaxImageSize; i++) {
+        tmp_signed_req_data[i] = (int8_t) (
+            (int32_t) (tmp_req_data[i]) - 128);
     }
+}
 
-    int ReadUserInputAsInt(hal_platform& platform)
-    {
-        char chInput[128];
-        memset(chInput, 0, sizeof(chInput));
+bool image::PresentInferenceResult(hal_platform& platform,
+                                   const std::vector<arm::app::ClassificationResult>& results)
+{
+    return PresentInferenceResult(platform, results, false);
+}
 
-        platform.data_acq->get_input(chInput, sizeof(chInput));
-        return atoi(chInput);
-    }
+bool image::PresentInferenceResult(hal_platform &platform,
+                                   const std::vector<arm::app::ClassificationResult> &results,
+                                   const time_t infTimeMs)
+{
+    return PresentInferenceResult(platform, results, true, infTimeMs);
+}
+
+
+bool image::PresentInferenceResult(hal_platform &platform,
+                                   const std::vector<arm::app::ClassificationResult> &results,
+                                   bool profilingEnabled,
+                                   const time_t infTimeMs)
+{
+    constexpr uint32_t dataPsnTxtStartX1 = 150;
+    constexpr uint32_t dataPsnTxtStartY1 = 30;
 
-    void DumpTensorData(const uint8_t* tensorData,
-                        size_t size,
-                        size_t lineBreakForNumElements)
+    constexpr uint32_t dataPsnTxtStartX2 = 10;
+    constexpr uint32_t dataPsnTxtStartY2 = 150;
+
+    constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
+
+    if(profilingEnabled)
+    {
+        platform.data_psn->set_text_color(COLOR_YELLOW);
+
+        /* If profiling is enabled, and the time is valid. */
+        info("Final results:\n");
+        info("Total number of inferences: 1\n");
+        if (infTimeMs) {
-        char strhex[8];
-        std::string strdump;
-
-        for (size_t i = 0; i < size; ++i) {
-            if (0 == i % lineBreakForNumElements) {
-                printf("%s\n\t", strdump.c_str());
-                strdump.clear();
-            }
-            snprintf(strhex, sizeof(strhex) - 1,
-                     "0x%02x, ", tensorData[i]);
-            strdump += std::string(strhex);
-        }
-
-        if (!strdump.empty()) {
-            printf("%s\n", strdump.c_str());
-        }
+            std::string strInf =
+                std::string{"Inference: "} +
+                std::to_string(infTimeMs) +
+                std::string{"ms"};
+            platform.data_psn->present_data_text(
+                strInf.c_str(), strInf.size(),
+                dataPsnTxtStartX1, dataPsnTxtStartY1, 0);
         }
+    }
+    platform.data_psn->set_text_color(COLOR_GREEN);
+
+    /* Display each result. */
+    uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
+    uint32_t rowIdx2 = dataPsnTxtStartY2;
 
-    void DumpTensor(const TfLiteTensor* tensor, const size_t lineBreakForNumElements)
+    if(!profilingEnabled)
     {
-        if (!tensor) {
-            printf_err("invalid tensor\n");
-            return;
+        info("Final results:\n");
+        info("Total number of inferences: 1\n");
+    }
+
+    for (uint32_t i = 0; i < results.size(); ++i) {
+        std::string resultStr =
+            std::to_string(i + 1) + ") " +
+            std::to_string(results[i].m_labelIdx) +
+            " (" + std::to_string(results[i].m_normalisedVal) + ")";
+
+        platform.data_psn->present_data_text(
+            resultStr.c_str(), resultStr.size(),
+            dataPsnTxtStartX1, rowIdx1, 0);
+        rowIdx1 += dataPsnTxtYIncr;
+
+        resultStr = std::to_string(i + 1) + ") " + results[i].m_label;
+        platform.data_psn->present_data_text(
+            resultStr.c_str(), resultStr.size(),
+            dataPsnTxtStartX2, rowIdx2, 0);
+        rowIdx2 += dataPsnTxtYIncr;
+
+        if(profilingEnabled)
+        {
+            info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", i, results[i].m_labelIdx,
+                 results[i].m_normalisedVal, results[i].m_label.c_str());
         }
+        else
+        {
+            info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", i,
+                 results[i].m_labelIdx, results[i].m_normalisedVal,
+                 results[i].m_label.c_str());
+        }
+    }
 
-        const uint32_t tensorSz = tensor->bytes;
-        const uint8_t* tensorData = tflite::GetTensorData<uint8_t>(tensor);
+    return true;
+}
 
-        DumpTensorData(tensorData, tensorSz, lineBreakForNumElements);
+void IncrementAppCtxIfmIdx(arm::app::ApplicationContext& ctx, std::string useCase)
+{
+    auto curImIdx = ctx.Get<uint32_t>(useCase);
+
+    if (curImIdx + 1 >= NUMBER_OF_FILES) {
+        ctx.Set<uint32_t>(useCase, 0);
+        return;
     }
+    ++curImIdx;
+    ctx.Set<uint32_t>(useCase, curImIdx);
+}
 
-    bool ListFilesHandler(ApplicationContext& ctx)
-    {
-        auto& model = ctx.Get<Model&>("model");
-        auto& platform = ctx.Get<hal_platform&>("platform");
+bool SetAppCtxIfmIdx(arm::app::ApplicationContext& ctx, uint32_t idx, std::string ctxIfmName)
+{
+    if (idx >= NUMBER_OF_FILES) {
+        printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n",
+                   idx, NUMBER_OF_FILES);
+        return false;
+    }
+    ctx.Set<uint32_t>(ctxIfmName, idx);
+    return true;
+}
+
+
+namespace arm {
+namespace app {
+
+
+bool RunInference(arm::app::Model& model, Profiler& profiler)
+{
+    profiler.StartProfiling("Inference");
+    bool runInf = model.RunInference();
+    profiler.StopProfiling();
+
+    return runInf;
+}
+
+int ReadUserInputAsInt(hal_platform& platform)
+{
+    char chInput[128];
+    memset(chInput, 0, sizeof(chInput));
+
+    platform.data_acq->get_input(chInput, sizeof(chInput));
+    return atoi(chInput);
+}
 
-        constexpr uint32_t dataPsnTxtStartX = 20;
-        constexpr uint32_t dataPsnTxtStartY = 40;
+void DumpTensorData(const uint8_t* tensorData,
+                    size_t size,
+                    size_t lineBreakForNumElements)
+{
+    char strhex[8];
+    std::string strdump;
 
-        if (!model.IsInited()) {
-            printf_err("Model is not initialised! Terminating processing.\n");
-            return false;
+    for (size_t i = 0; i < size; ++i) {
+        if (0 == i % lineBreakForNumElements) {
+            printf("%s\n\t", strdump.c_str());
+            strdump.clear();
         }
+        snprintf(strhex, sizeof(strhex) - 1,
+                 "0x%02x, ", tensorData[i]);
+        strdump += std::string(strhex);
+    }
+
+    if (!strdump.empty()) {
+        printf("%s\n", strdump.c_str());
+    }
+}
+
+void DumpTensor(const TfLiteTensor* tensor, const size_t lineBreakForNumElements)
+{
+    if (!tensor) {
+        printf_err("invalid tensor\n");
+        return;
+    }
+
+    const uint32_t tensorSz = tensor->bytes;
+    const uint8_t* tensorData = tflite::GetTensorData<uint8_t>(tensor);
 
-        /* Clear the LCD */
-        platform.data_psn->clear(COLOR_BLACK);
+    DumpTensorData(tensorData, tensorSz, lineBreakForNumElements);
+}
 
-        /* Show the total number of embedded files. */
-        std::string strNumFiles = std::string{"Total Number of Files: "} +
-                                  std::to_string(NUMBER_OF_FILES);
-        platform.data_psn->present_data_text(strNumFiles.c_str(),
-                                             strNumFiles.size(),
-                                             dataPsnTxtStartX,
-                                             dataPsnTxtStartY,
-                                             false);
+bool ListFilesHandler(ApplicationContext& ctx)
+{
+    auto& model = ctx.Get<Model&>("model");
+    auto& platform = ctx.Get<hal_platform&>("platform");
+
+    constexpr uint32_t dataPsnTxtStartX = 20;
+    constexpr uint32_t dataPsnTxtStartY = 40;
+
+    if (!model.IsInited()) {
+        printf_err("Model is not initialised! Terminating processing.\n");
+        return false;
+    }
+
+    /* Clear the LCD */
+    platform.data_psn->clear(COLOR_BLACK);
+
+    /* Show the total number of embedded files. */
+    std::string strNumFiles = std::string{"Total Number of Files: "} +
+                              std::to_string(NUMBER_OF_FILES);
+    platform.data_psn->present_data_text(strNumFiles.c_str(),
+                                         strNumFiles.size(),
+                                         dataPsnTxtStartX,
+                                         dataPsnTxtStartY,
+                                         false);
 
 #if NUMBER_OF_FILES > 0
         constexpr uint32_t dataPsnTxtYIncr = 16;
@@ -117,7 +256,7 @@ namespace app {
 #endif /* NUMBER_OF_FILES > 0 */
 
         return true;
-    }
+}
 
 } /* namespace app */
 } /* namespace arm */
\ No newline at end of file
--
cgit v1.2.1
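
For context, a minimal sketch of how a use case's main loop might drive the shared helpers this change introduces (DisplayCommonMenu, ReadUserInputAsInt, ListFilesHandler and the common::MENU_OPT_* options). ClassifyImageHandler, the "imgIndex" context key and the header names are assumptions modelled on the img_class use case, not code taken from this patch:

/* Illustrative sketch only, assuming the helper signatures shown in the diff above. */
#include <cstdio>
#include <cstdint>

#include "hal.h"                   /* hal_platform (assumed header name). */
#include "AppContext.hpp"          /* arm::app::ApplicationContext (assumed header name). */
#include "UseCaseCommonUtils.hpp"  /* Shared helpers refactored by this patch. */
#include "InputFiles.hpp"          /* NUMBER_OF_FILES. */

/* Hypothetical per-use-case handler, as a use case like img_class would provide. */
bool ClassifyImageHandler(arm::app::ApplicationContext& ctx, uint32_t imgIndex, bool runAll);

void main_loop(hal_platform& platform)
{
    arm::app::ApplicationContext caseContext;
    caseContext.Set<hal_platform&>("platform", platform);
    caseContext.Set<uint32_t>("imgIndex", 0);
    /* A real use case also registers "model" and "profiler" here, since handlers
     * such as ListFilesHandler fetch them from the context. */

    bool executionSuccessful = true;
    do {
        DisplayCommonMenu();                      /* Shared menu printed by the helper. */
        int menuOption = arm::app::ReadUserInputAsInt(platform);

        switch (menuOption) {
            case common::MENU_OPT_RUN_INF_NEXT:
                executionSuccessful = ClassifyImageHandler(
                        caseContext, caseContext.Get<uint32_t>("imgIndex"), false);
                break;
            case common::MENU_OPT_RUN_INF_CHOSEN: {
                printf("    Enter the image index [0, %u]: ", (unsigned)(NUMBER_OF_FILES - 1));
                fflush(stdout);
                auto imgIndex = static_cast<uint32_t>(arm::app::ReadUserInputAsInt(platform));
                executionSuccessful = ClassifyImageHandler(caseContext, imgIndex, false);
                break;
            }
            case common::MENU_OPT_RUN_INF_ALL:
                executionSuccessful = ClassifyImageHandler(
                        caseContext, caseContext.Get<uint32_t>("imgIndex"), true);
                break;
            case common::MENU_OPT_LIST_IFM:
                executionSuccessful = arm::app::ListFilesHandler(caseContext);
                break;
            default:  /* MENU_OPT_SHOW_MODEL_INFO omitted for brevity. */
                printf("Incorrect choice, try again.\n");
                break;
        }
    } while (executionSuccessful);
}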