diff options
Diffstat (limited to 'source/use_case')
-rw-r--r-- | source/use_case/ad/src/UseCaseHandler.cc | 40 | ||||
-rw-r--r-- | source/use_case/asr/src/UseCaseHandler.cc | 41 | ||||
-rw-r--r-- | source/use_case/img_class/src/MainLoop.cc | 37 | ||||
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 116 | ||||
-rw-r--r-- | source/use_case/kws/src/UseCaseHandler.cc | 43 | ||||
-rw-r--r-- | source/use_case/kws_asr/src/UseCaseHandler.cc | 40 | ||||
-rw-r--r-- | source/use_case/vww/include/UseCaseHandler.hpp | 37 | ||||
-rw-r--r-- | source/use_case/vww/include/VisualWakeWordModel.hpp | 48 | ||||
-rw-r--r-- | source/use_case/vww/src/MainLoop.cc | 91 | ||||
-rw-r--r-- | source/use_case/vww/src/UseCaseHandler.cc | 182 | ||||
-rw-r--r-- | source/use_case/vww/src/VisualWakeWordModel.cc | 57 | ||||
-rw-r--r-- | source/use_case/vww/usecase.cmake | 62 |
12 files changed, 498 insertions, 296 deletions
diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc index 0c78179..b20b63e 100644 --- a/source/use_case/ad/src/UseCaseHandler.cc +++ b/source/use_case/ad/src/UseCaseHandler.cc @@ -29,20 +29,6 @@ namespace arm { namespace app { /** - * @brief Helper function to increment current audio clip index - * @param[in,out] ctx pointer to the application context object - **/ - static void IncrementAppCtxClipIdx(ApplicationContext& ctx); - - /** - * @brief Helper function to set the audio clip index - * @param[in,out] ctx pointer to the application context object - * @param[in] idx value to be set - * @return true if index is set, false otherwise - **/ - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); - - /** * @brief Presents inference results using the data presentation * object. * @param[in] platform reference to the hal platform object @@ -88,7 +74,7 @@ namespace app { /* If the request has a valid size, set the audio index */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxClipIdx(ctx, clipIndex)) { + if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { return false; } } @@ -225,35 +211,13 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxClipIdx(ctx); + IncrementAppCtxIfmIdx(ctx,"clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); return true; } - static void IncrementAppCtxClipIdx(ApplicationContext& ctx) - { - auto curAudioIdx = ctx.Get<uint32_t>("clipIndex"); - - if (curAudioIdx + 1 >= NUMBER_OF_FILES) { - ctx.Set<uint32_t>("clipIndex", 0); - return; - } - ++curAudioIdx; - ctx.Set<uint32_t>("clipIndex", curAudioIdx); - } - - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx) - { - if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); - return false; - } - ctx.Set<uint32_t>("clipIndex", idx); - return true; - } static bool PresentInferenceResult(hal_platform& platform, float result, float threshold) { diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc index 8ef318f..d469255 100644 --- a/source/use_case/asr/src/UseCaseHandler.cc +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -32,20 +32,6 @@ namespace arm { namespace app { /** - * @brief Helper function to increment current audio clip index. - * @param[in,out] ctx Pointer to the application context object. - **/ - static void IncrementAppCtxClipIdx(ApplicationContext& ctx); - - /** - * @brief Helper function to set the audio clip index. - * @param[in,out] ctx Pointer to the application context object. - * @param[in] idx Value to be set. - * @return true if index is set, false otherwise. - **/ - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); - - /** * @brief Presents inference results using the data presentation * object. * @param[in] platform Reference to the hal platform object. @@ -69,7 +55,7 @@ namespace app { /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxClipIdx(ctx, clipIndex)) { + if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { return false; } } @@ -214,36 +200,13 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxClipIdx(ctx); + IncrementAppCtxIfmIdx(ctx,"clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); return true; } - static void IncrementAppCtxClipIdx(ApplicationContext& ctx) - { - auto curAudioIdx = ctx.Get<uint32_t>("clipIndex"); - - if (curAudioIdx + 1 >= NUMBER_OF_FILES) { - ctx.Set<uint32_t>("clipIndex", 0); - return; - } - ++curAudioIdx; - ctx.Set<uint32_t>("clipIndex", curAudioIdx); - } - - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx) - { - if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); - return false; - } - - ctx.Set<uint32_t>("clipIndex", idx); - return true; - } static bool PresentInferenceResult(hal_platform& platform, const std::vector<arm::app::asr::AsrResult>& results) diff --git a/source/use_case/img_class/src/MainLoop.cc b/source/use_case/img_class/src/MainLoop.cc index 61a09dd..79f6018 100644 --- a/source/use_case/img_class/src/MainLoop.cc +++ b/source/use_case/img_class/src/MainLoop.cc @@ -24,29 +24,6 @@ using ImgClassClassifier = arm::app::Classifier; -enum opcodes -{ - MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ - MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index. */ - MENU_OPT_RUN_INF_ALL, /* Run inference on all. */ - MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ - MENU_OPT_LIST_IMAGES /* List the current baked images. */ -}; - -static void DisplayMenu() -{ - printf("\n\n"); - printf("User input required\n"); - printf("Enter option number from:\n\n"); - printf(" %u. Classify next image\n", MENU_OPT_RUN_INF_NEXT); - printf(" %u. Classify image at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); - printf(" %u. Run classification on all images\n", MENU_OPT_RUN_INF_ALL); - printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); - printf(" %u. List images\n\n", MENU_OPT_LIST_IMAGES); - printf(" Choice: "); - fflush(stdout); -} - void main_loop(hal_platform& platform) { arm::app::MobileNetModel model; /* Model wrapper object. */ @@ -79,29 +56,29 @@ void main_loop(hal_platform& platform) /* Loop. */ do { - int menuOption = MENU_OPT_RUN_INF_NEXT; + int menuOption = common::MENU_OPT_RUN_INF_NEXT; if (bUseMenu) { - DisplayMenu(); + DisplayCommonMenu(); menuOption = arm::app::ReadUserInputAsInt(platform); printf("\n"); } switch (menuOption) { - case MENU_OPT_RUN_INF_NEXT: + case common::MENU_OPT_RUN_INF_NEXT: executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), false); break; - case MENU_OPT_RUN_INF_CHOSEN: { + case common::MENU_OPT_RUN_INF_CHOSEN: { printf(" Enter the image index [0, %d]: ", NUMBER_OF_FILES-1); auto imgIndex = static_cast<uint32_t>(arm::app::ReadUserInputAsInt(platform)); executionSuccessful = ClassifyImageHandler(caseContext, imgIndex, false); break; } - case MENU_OPT_RUN_INF_ALL: + case common::MENU_OPT_RUN_INF_ALL: executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), true); break; - case MENU_OPT_SHOW_MODEL_INFO: + case common::MENU_OPT_SHOW_MODEL_INFO: executionSuccessful = model.ShowModelInfoHandler(); break; - case MENU_OPT_LIST_IMAGES: + case common::MENU_OPT_LIST_IFM: executionSuccessful = ListFilesHandler(caseContext); break; default: diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index 337cb29..66df1da 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -39,37 +39,6 @@ namespace app { **/ static bool LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor); - /** - * @brief Helper function to increment current image index. - * @param[in,out] ctx Pointer to the application context object. - **/ - static void IncrementAppCtxImageIdx(ApplicationContext& ctx); - - /** - * @brief Helper function to set the image index. - * @param[in,out] ctx Pointer to the application context object. - * @param[in] idx Value to be set. - * @return true if index is set, false otherwise. - **/ - static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx); - - /** - * @brief Presents inference results using the data presentation - * object. - * @param[in] platform Reference to the hal platform object. - * @param[in] results Vector of classification results to be displayed. - * @return true if successful, false otherwise. - **/ - static bool PresentInferenceResult(hal_platform& platform, - const std::vector<ClassificationResult>& results); - - /** - * @brief Helper function to convert a UINT8 image to INT8 format. - * @param[in,out] data Pointer to the data start. - * @param[in] kMaxImageSize Total number of pixels in the image. - **/ - static void ConvertImgToInt8(void* data, size_t kMaxImageSize); - /* Image inference classification handler. */ bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) { @@ -89,7 +58,7 @@ namespace app { /* If the request has a valid size, set the image index. */ if (imgIndex < NUMBER_OF_FILES) { - if (!SetAppCtxImageIdx(ctx, imgIndex)) { + if (!SetAppCtxIfmIdx(ctx, imgIndex, "imgIndex")) { return false; } } @@ -134,7 +103,7 @@ namespace app { /* If the data is signed. */ if (model.IsDataSigned()) { - ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); } /* Display message on the LCD - inference running. */ @@ -166,13 +135,13 @@ namespace app { arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ - if (!PresentInferenceResult(platform, results)) { + if (!image::PresentInferenceResult(platform, results)) { return false; } profiler.PrintProfilingResult(); - IncrementAppCtxImageIdx(ctx); + IncrementAppCtxIfmIdx(ctx,"imgIndex"); } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); @@ -195,83 +164,6 @@ namespace app { return true; } - static void IncrementAppCtxImageIdx(ApplicationContext& ctx) - { - auto curImIdx = ctx.Get<uint32_t>("imgIndex"); - - if (curImIdx + 1 >= NUMBER_OF_FILES) { - ctx.Set<uint32_t>("imgIndex", 0); - return; - } - ++curImIdx; - ctx.Set<uint32_t>("imgIndex", curImIdx); - } - - static bool SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx) - { - if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); - return false; - } - ctx.Set<uint32_t>("imgIndex", idx); - return true; - } - - static bool PresentInferenceResult(hal_platform& platform, - const std::vector<ClassificationResult>& results) - { - constexpr uint32_t dataPsnTxtStartX1 = 150; - constexpr uint32_t dataPsnTxtStartY1 = 30; - - constexpr uint32_t dataPsnTxtStartX2 = 10; - constexpr uint32_t dataPsnTxtStartY2 = 150; - - constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ - - platform.data_psn->set_text_color(COLOR_GREEN); - - /* Display each result. */ - uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; - uint32_t rowIdx2 = dataPsnTxtStartY2; - - info("Final results:\n"); - info("Total number of inferences: 1\n"); - for (uint32_t i = 0; i < results.size(); ++i) { - std::string resultStr = - std::to_string(i + 1) + ") " + - std::to_string(results[i].m_labelIdx) + - " (" + std::to_string(results[i].m_normalisedVal) + ")"; - - platform.data_psn->present_data_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX1, rowIdx1, 0); - rowIdx1 += dataPsnTxtYIncr; - - resultStr = std::to_string(i + 1) + ") " + results[i].m_label; - platform.data_psn->present_data_text( - resultStr.c_str(), resultStr.size(), - dataPsnTxtStartX2, rowIdx2, 0); - rowIdx2 += dataPsnTxtYIncr; - - info("%" PRIu32 ") %" PRIu32 " (%f) -> %s\n", i, - results[i].m_labelIdx, results[i].m_normalisedVal, - results[i].m_label.c_str()); - } - - return true; - } - - static void ConvertImgToInt8(void* data, const size_t kMaxImageSize) - { - auto* tmp_req_data = (uint8_t*) data; - auto* tmp_signed_req_data = (int8_t*) data; - - for (size_t i = 0; i < kMaxImageSize; i++) { - tmp_signed_req_data[i] = (int8_t) ( - (int32_t) (tmp_req_data[i]) - 128); - } - } } /* namespace app */ } /* namespace arm */ diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 2144c03..a951e55 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -33,20 +33,7 @@ using KwsClassifier = arm::app::Classifier; namespace arm { namespace app { - /** - * @brief Helper function to increment current audio clip index. - * @param[in,out] ctx Pointer to the application context object. - **/ - static void IncrementAppCtxClipIdx(ApplicationContext& ctx); - - /** - * @brief Helper function to set the audio clip index. - * @param[in,out] ctx Pointer to the application context object. - * @param[in] idx Value to be set. - * @return true if index is set, false otherwise. - **/ - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); - + /** * @brief Presents inference results using the data presentation * object. @@ -94,7 +81,7 @@ namespace app { /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxClipIdx(ctx, clipIndex)) { + if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) { return false; } } @@ -246,36 +233,14 @@ namespace app { profiler.PrintProfilingResult(); - IncrementAppCtxClipIdx(ctx); + IncrementAppCtxIfmIdx(ctx,"clipIndex"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); return true; } - static void IncrementAppCtxClipIdx(ApplicationContext& ctx) - { - auto curAudioIdx = ctx.Get<uint32_t>("clipIndex"); - - if (curAudioIdx + 1 >= NUMBER_OF_FILES) { - ctx.Set<uint32_t>("clipIndex", 0); - return; - } - ++curAudioIdx; - ctx.Set<uint32_t>("clipIndex", curAudioIdx); - } - - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx) - { - if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); - return false; - } - ctx.Set<uint32_t>("clipIndex", idx); - return true; - } - + static bool PresentInferenceResult(hal_platform& platform, const std::vector<arm::app::kws::KwsResult>& results) { diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc index 9080348..1d88ba1 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -49,20 +49,6 @@ namespace app { }; /** - * @brief Helper function to increment current audio clip index - * @param[in,out] ctx pointer to the application context object - **/ - static void IncrementAppCtxClipIdx(ApplicationContext& ctx); - - /** - * @brief Helper function to set the audio clip index - * @param[in,out] ctx pointer to the application context object - * @param[in] idx value to be set - * @return true if index is set, false otherwise - **/ - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); - - /** * @brief Presents kws inference results using the data presentation * object. * @param[in] platform reference to the hal platform object @@ -440,7 +426,7 @@ namespace app { /* If the request has a valid size, set the audio index. */ if (clipIndex < NUMBER_OF_FILES) { - if (!SetAppCtxClipIdx(ctx, clipIndex)) { + if (!SetAppCtxIfmIdx(ctx, clipIndex,"kws_asr")) { return false; } } @@ -461,35 +447,13 @@ namespace app { } } - IncrementAppCtxClipIdx(ctx); + IncrementAppCtxIfmIdx(ctx,"kws_asr"); } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx); return true; } - static void IncrementAppCtxClipIdx(ApplicationContext& ctx) - { - auto curAudioIdx = ctx.Get<uint32_t>("clipIndex"); - - if (curAudioIdx + 1 >= NUMBER_OF_FILES) { - ctx.Set<uint32_t>("clipIndex", 0); - return; - } - ++curAudioIdx; - ctx.Set<uint32_t>("clipIndex", curAudioIdx); - } - - static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx) - { - if (idx >= NUMBER_OF_FILES) { - printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n", - idx, NUMBER_OF_FILES); - return false; - } - ctx.Set<uint32_t>("clipIndex", idx); - return true; - } static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::kws::KwsResult>& results) diff --git a/source/use_case/vww/include/UseCaseHandler.hpp b/source/use_case/vww/include/UseCaseHandler.hpp new file mode 100644 index 0000000..7476ed8 --- /dev/null +++ b/source/use_case/vww/include/UseCaseHandler.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef VISUAL_WAKE_WORD_HANDLER_HPP +#define VISUAL_WAKE_WORD_HANDLER_HPP + +#include "AppContext.hpp" + +namespace arm { +namespace app { + + /** + * @brief Handles the inference event. + * @param[in] ctx Pointer to the application context. + * @param[in] imgIndex Index to the image to classify. + * @param[in] runAll Flag to request classification of the available images. + * @return true or false based on execution success. + **/ + bool ClassifyImageHandler(ApplicationContext &ctx, uint32_t imgIndex, bool runAll); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* VISUAL_WAKE_WORD_HANDLER_HPP */ diff --git a/source/use_case/vww/include/VisualWakeWordModel.hpp b/source/use_case/vww/include/VisualWakeWordModel.hpp new file mode 100644 index 0000000..ee3a7bf --- /dev/null +++ b/source/use_case/vww/include/VisualWakeWordModel.hpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef VISUAL_WAKE_WORD_MODEL_HPP +#define VISUAL_WAKE_WORD_MODEL_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { + + class VisualWakeWordModel : public Model { + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. */ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int ms_maxOpCnt = 7; + + /* A mutable op resolver instance. */ + tflite::MicroMutableOpResolver<ms_maxOpCnt> m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* VISUAL_WAKE_WORD_MODEL_HPP */ diff --git a/source/use_case/vww/src/MainLoop.cc b/source/use_case/vww/src/MainLoop.cc new file mode 100644 index 0000000..f026cc2 --- /dev/null +++ b/source/use_case/vww/src/MainLoop.cc @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* Brings in platform definitions. */ +#include "Classifier.hpp" /* Classifier. */ +#include "InputFiles.hpp" /* For input images. */ +#include "Labels.hpp" /* For label strings. */ +#include "VisualWakeWordModel.hpp" /* Model class for running inference. */ +#include "UseCaseHandler.hpp" /* Handlers for different user options. */ +#include "UseCaseCommonUtils.hpp" /* Utils functions. */ + +using ViusalWakeWordClassifier = arm::app::Classifier; + +void main_loop(hal_platform &platform) +{ + arm::app::VisualWakeWordModel model; /* Model wrapper object. */ + + /* Load the model. */ + if (!model.Init()) { + printf_err("Failed to initialise model\n"); + return; + } + + /* Instantiate application context. */ + arm::app::ApplicationContext caseContext; + + arm::app::Profiler profiler{&platform, "vww"}; + caseContext.Set<arm::app::Profiler&>("profiler", profiler); + caseContext.Set<hal_platform&>("platform", platform); + caseContext.Set<arm::app::Model&>("model", model); + caseContext.Set<uint32_t>("imgIndex", 0); + + ViusalWakeWordClassifier classifier; /* Classifier wrapper object. */ + caseContext.Set<arm::app::Classifier&>("classifier", classifier); + + std::vector <std::string> labels; + GetLabelsVector(labels); + caseContext.Set<const std::vector <std::string>&>("labels", labels); + + /* Loop. */ + bool executionSuccessful = true; + constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; + do { + int menuOption = common::MENU_OPT_RUN_INF_NEXT; + if (bUseMenu) { + DisplayCommonMenu(); + menuOption = arm::app::ReadUserInputAsInt(platform); + printf("\n"); + } + + switch (menuOption) { + case common::MENU_OPT_RUN_INF_NEXT: + executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), false); + break; + case common::MENU_OPT_RUN_INF_CHOSEN: { + printf(" Enter the image index [0, %d]: ", NUMBER_OF_FILES-1); + auto imgIndex = static_cast<uint32_t>(arm::app::ReadUserInputAsInt(platform)); + executionSuccessful = ClassifyImageHandler(caseContext, imgIndex, false); + break; + } + case common::MENU_OPT_RUN_INF_ALL: + executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get<uint32_t>("imgIndex"), true); + break; + case common::MENU_OPT_SHOW_MODEL_INFO: { + executionSuccessful = model.ShowModelInfoHandler(); + break; + } + case common::MENU_OPT_LIST_IFM: + executionSuccessful = ListFilesHandler(caseContext); + break; + default: + printf("Incorrect choice, try again."); + break; + } + } while (executionSuccessful && bUseMenu); + info("Main loop terminated.\n"); + +} diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc new file mode 100644 index 0000000..fb2e837 --- /dev/null +++ b/source/use_case/vww/src/UseCaseHandler.cc @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "UseCaseHandler.hpp" +#include "VisualWakeWordModel.hpp" +#include "Classifier.hpp" +#include "InputFiles.hpp" +#include "UseCaseCommonUtils.hpp" +#include "hal.h" + +namespace arm { +namespace app { + + /** + * @brief Helper function to load the current image into the input + * tensor. + * @param[in] imIdx Image index (from the pool of images available + * to the application). + * @param[out] inputTensor Pointer to the input tensor to be populated. + * @return true if tensor is loaded, false otherwise. + **/ + static bool LoadImageIntoTensor(uint32_t imIdx, + TfLiteTensor *inputTensor); + + /* Image inference classification handler. */ + bool ClassifyImageHandler(ApplicationContext &ctx, uint32_t imgIndex, bool runAll) + { + auto& platform = ctx.Get<hal_platform &>("platform"); + auto& profiler = ctx.Get<Profiler&>("profiler"); + + constexpr uint32_t dataPsnImgDownscaleFactor = 1; + constexpr uint32_t dataPsnImgStartX = 10; + constexpr uint32_t dataPsnImgStartY = 35; + + constexpr uint32_t dataPsnTxtInfStartX = 150; + constexpr uint32_t dataPsnTxtInfStartY = 70; + + + platform.data_psn->clear(COLOR_BLACK); + time_t infTimeMs = 0; + + auto& model = ctx.Get<Model&>("model"); + + /* If the request has a valid size, set the image index. */ + if (imgIndex < NUMBER_OF_FILES) { + if (!SetAppCtxIfmIdx(ctx, imgIndex,"imgIndex")) { + return false; + } + } + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + auto curImIdx = ctx.Get<uint32_t>("imgIndex"); + + TfLiteTensor *outputTensor = model.GetOutputTensor(0); + TfLiteTensor *inputTensor = model.GetInputTensor(0); + + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; + } + TfLiteIntArray* inputShape = model.GetInputShape(0); + const uint32_t nCols = inputShape->data[2]; + const uint32_t nRows = inputShape->data[1]; + const uint32_t nChannels = (inputShape->size == 4) ? inputShape->data[3] : 1; + + std::vector<ClassificationResult> results; + + do { + + /* Strings for presentation/logging. */ + std::string str_inf{"Running inference... "}; + + /* Copy over the data. */ + LoadImageIntoTensor(ctx.Get<uint32_t>("imgIndex"), inputTensor); + + /* Display this image on the LCD. */ + platform.data_psn->present_data_image( + (uint8_t *) inputTensor->data.data, + nCols, nRows, nChannels, + dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); + + /* If the data is signed. */ + if (model.IsDataSigned()) { + image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + } + + /* Display message on the LCD - inference running. */ + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + /* Run inference over this image. */ + info("Running inference on image %" PRIu32 " => %s\n", ctx.Get<uint32_t>("imgIndex"), + get_filename(ctx.Get<uint32_t>("imgIndex"))); + + if (!RunInference(model, profiler)) { + return false; + } + + /* Erase. */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + auto& classifier = ctx.Get<Classifier&>("classifier"); + classifier.GetClassificationResults(outputTensor, results, + ctx.Get<std::vector <std::string>&>("labels"), 1); + + /* Add results to context for access outside handler. */ + ctx.Set<std::vector<ClassificationResult>>("results", results); + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(outputTensor); +#endif /* VERIFY_TEST_OUTPUT */ + + if (!image::PresentInferenceResult(platform, results, infTimeMs)) { + return false; + } + + profiler.PrintProfilingResult(); + IncrementAppCtxIfmIdx(ctx,"imgIndex"); + + } while (runAll && ctx.Get<uint32_t>("imgIndex") != curImIdx); + + return true; + } + + static bool LoadImageIntoTensor(const uint32_t imIdx, + TfLiteTensor *inputTensor) + { + const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? + inputTensor->bytes : IMAGE_DATA_SIZE; + if (imIdx >= NUMBER_OF_FILES) { + printf_err("invalid image index %" PRIu32 " (max: %u)\n", imIdx, + NUMBER_OF_FILES - 1); + return false; + } + + const uint32_t nChannels = (inputTensor->dims->size == 4) ? inputTensor->dims->data[3] : 1; + + const uint8_t* srcPtr = get_img_array(imIdx); + auto* dstPtr = (uint8_t*)inputTensor->data.data; + if (1 == nChannels) { + /** + * Visual Wake Word model accepts only one channel => + * Convert image to grayscale here + **/ + for (size_t i = 0; i < copySz; ++i, srcPtr += 3) { + *dstPtr++ = 0.2989*(*srcPtr) + + 0.587*(*(srcPtr+1)) + + 0.114*(*(srcPtr+2)); + } + } else { + memcpy(inputTensor->data.data, srcPtr, copySz); + } + + debug("Image %" PRIu32 " loaded\n", imIdx); + return true; + } + +} /* namespace app */ +} /* namespace arm */
\ No newline at end of file diff --git a/source/use_case/vww/src/VisualWakeWordModel.cc b/source/use_case/vww/src/VisualWakeWordModel.cc new file mode 100644 index 0000000..3067c7a --- /dev/null +++ b/source/use_case/vww/src/VisualWakeWordModel.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "VisualWakeWordModel.hpp" + +#include "hal.h" + +const tflite::MicroOpResolver& arm::app::VisualWakeWordModel::GetOpResolver() +{ + return this->m_opResolver; +} + +bool arm::app::VisualWakeWordModel::EnlistOperations() +{ + this->m_opResolver.AddDepthwiseConv2D(); + this->m_opResolver.AddConv2D(); + this->m_opResolver.AddAveragePool2D(); + this->m_opResolver.AddReshape(); + this->m_opResolver.AddPad(); + this->m_opResolver.AddAdd(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::VisualWakeWordModel::ModelPointer() +{ + return GetModelPointer(); +} + +extern size_t GetModelLen(); +size_t arm::app::VisualWakeWordModel::ModelSize() +{ + return GetModelLen(); +}
\ No newline at end of file diff --git a/source/use_case/vww/usecase.cmake b/source/use_case/vww/usecase.cmake new file mode 100644 index 0000000..9a732b7 --- /dev/null +++ b/source/use_case/vww/usecase.cmake @@ -0,0 +1,62 @@ +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +USER_OPTION(${use_case}_FILE_PATH "Directory with custom image files, or path to a single image file, to use in the evaluation application" + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/samples/ + PATH_OR_FILE) + +USER_OPTION(${use_case}_IMAGE_SIZE "Square image size in pixels. Images will be resized to this size." + 128 + STRING) + +USER_OPTION(${use_case}_LABELS_TXT_FILE "Labels' txt file for the chosen model" + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/visual_wake_word_labels.txt + FILEPATH) + +USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen model" + 0x00200000 + STRING) + +if (ETHOS_U55_ENABLED) + set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/vww4_128_128_INT8_vela_H128.tflite) +else() + set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/vww4_128_128_INT8.tflite) +endif() + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH "NN models file to be used in the evaluation application. Model files must be in tflite format." + ${DEFAULT_MODEL_PATH} + FILEPATH) + +# Generate model file +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH} + DESTINATION ${SRC_GEN_DIR} +) + +# Generate labels file +set(${use_case}_LABELS_CPP_FILE Labels) +generate_labels_code( + INPUT "${${use_case}_LABELS_TXT_FILE}" + DESTINATION_SRC ${SRC_GEN_DIR} + DESTINATION_HDR ${INC_GEN_DIR} + OUTPUT_FILENAME "${${use_case}_LABELS_CPP_FILE}" +) + +# Generate input files +generate_images_code("${${use_case}_FILE_PATH}" + ${SRC_GEN_DIR} + ${INC_GEN_DIR} + "${${use_case}_IMAGE_SIZE}") |