diff options
Diffstat (limited to 'source/use_case/img_class/src/UseCaseHandler.cc')
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 116 |
1 files changed, 4 insertions, 112 deletions
diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index 337cb29..66df1da 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -39,37 +39,6 @@ namespace app { **/ static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor); - /** - * @brief Helper function to increment current image index. - * @param[in,out] ctx Pointer to the application context object. - **/ - static void IncrementAppCtxImageIdx(ApplicationContext& ctx); - - /** - * @brief Helper function to set the image index. - * @param[in,out] ctx Pointer to the application context object. - * @param[in] idx Value to be set. - * @return true if index is set, false otherwise. - **/ - static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx); - - /** - * @brief Presents inference results using the data presentation - * object. - * @param[in] platform Reference to the hal platform object. - * @param[in] results Vector of classification results to be displayed. - * @return true if successful, false otherwise. - **/ - static bool PresentInferenceResult(hal_platform& platform, - const std::vector<ClassificationResult>& results); - - /** - * @brief Helper function to convert a UINT8 image to INT8 format. - * @param[in,out] data Pointer to the data start. - * @param[in] kMaxImageSize Total number of pixels in the image. - **/ - static void ConvertImgToInt8(void* data, size_t kMaxImageSize); - /* Image inference classification handler. */ bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { @@ -89,7 +58,7 @@ namespace app { /* If the request has a valid size, set the image index. */ if (imgIndex < NUMBER_OF_FILES) { - if (!SetAppCtxImageIdx(ctx, imgIndex)) { + if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) { return false; } } @@ -134,7 +103,7 @@ namespace app { /* If the data is signed. */ if (model.IsDataSigned()) { - ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); } /* Display message on the LCD - inference running. */ @@ -166,13 +135,13 @@ namespace app { arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ - if (!PresentInferenceResult(platform, results)) { + if (!image::PresentInferenceResult(platform, results)) { return false; } profiler.PrintProfilingResult(); - IncrementAppCtxImageIdx(ctx); + IncrementAppCtxIfmIdx(ctx,"imgIndex"); } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); @@ -195,83 +164,6 @@ namespace app { return true; } - static void IncrementAppCtxImageIdx(ApplicationContext& ctx) - { - auto curImIdx = ctx.Get<uint32_t>("imgIndex"); - - if (curImIdx + 1 >= NUMBER_OF_FILES) { - ctx.Set<uint32_t>("imgIndex", 0); - return; - } - ++curImIdx; - ctx.Set<uint32_t>("imgIndex", curImIdx); - } - - static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx) - { - if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); - return false; - } - ctx.Set<uint32_t>("imgIndex", idx); - return true; - } - - static bool PresentInferenceResult(hal_platform& platform, - const std::vector<ClassificationResult>& results) - { - constexpr uint32_t dataPsnTxtStartX1 = 150; - constexpr uint32_t dataPsnTxtStartY1 = 30; - - constexpr uint32_t dataPsnTxtStartX2 = 10; - constexpr uint32_t dataPsnTxtStartY2 = 150; - - constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ - - platform.data_psn->set_text_color(COLOR_GREEN); - - /* Display each result. */ - uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - uint32_t rowIdx2 = dataPsnTxtStartY2; - - info("Final results:\n"); - info("Total number of inferences: 1\n"); - for (uint32_t i = 0; i < results.size(); ++i) { - std::string resultStr = - std::to_string(i + 1) + ") " + - std::to_string(results[i].m_labelIdx) + - " (" + std::to_string(results[i].m_normalisedVal) + ")"; - - platform.data_psn->present_data_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, 0); - rowIdx1 += dataPsnTxtYIncr; - - resultStr = std::to_string(i + 1) + ") " + results[i].m_label; - platform.data_psn->present_data_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX2, rowIdx2, 0); - rowIdx2 += dataPsnTxtYIncr; - - info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", i, - results[i].m_labelIdx, results[i].m_normalisedVal, - results[i].m_label.c_str()); - } - - return true; - } - - static void ConvertImgToInt8(void* data, const size_t kMaxImageSize) - { - auto* tmp_req_data = (uint8_t*) data; - auto* tmp_signed_req_data = (int8_t*) data; - - for (size_t i = 0; i < kMaxImageSize; i++) { - tmp_signed_req_data[i] = (int8_t) ( - (int32_t) (tmp_req_data[i]) - 128); - } - } } /* namespace app */ } /* namespace arm */ |