From 3c79893217bc632c9b0efa815091bef3c779490c Mon Sep 17 00:00:00 2001 From: alexander Date: Fri, 26 Mar 2021 21:42:19 +0000 Subject: Opensource ML embedded evaluation kit Change-Id: I12e807f19f5cacad7cef82572b6dd48252fd61fd --- source/application/hal/hal.c | 264 ++++++++ source/application/hal/include/data_acq.h | 52 ++ source/application/hal/include/data_psn.h | 72 +++ source/application/hal/include/hal.h | 81 +++ source/application/hal/include/hal_config.h | 43 ++ source/application/hal/include/timer.h | 80 +++ .../bare-metal/bsp/bsp-core/include/bsp_core_log.h | 66 ++ .../bare-metal/bsp/bsp-core/include/uart_stdout.h | 57 ++ .../platforms/bare-metal/bsp/bsp-core/retarget.c | 235 +++++++ .../bare-metal/bsp/bsp-packs/mps3/device_mps3.c | 36 ++ .../bare-metal/bsp/bsp-packs/mps3/glcd_mps3.c | 460 ++++++++++++++ .../bsp/bsp-packs/mps3/include/device_mps3.h | 111 ++++ .../bsp/bsp-packs/mps3/include/font_9x15_h.h | 128 ++++ .../bsp/bsp-packs/mps3/include/glcd_mps3.h | 202 ++++++ .../bsp/bsp-packs/mps3/include/smm_mps3.h | 615 ++++++++++++++++++ .../bsp/bsp-packs/mps3/include/timer_mps3.h | 86 +++ .../bare-metal/bsp/bsp-packs/mps3/timer_mps3.c | 112 ++++ .../bare-metal/bsp/bsp-packs/mps3/uart_stdout.c | 132 ++++ .../bsp-packs/simple_platform/include/stubs_fvp.h | 124 ++++ .../bsp-packs/simple_platform/include/timer_fvp.h | 55 ++ .../bsp/bsp-packs/simple_platform/stubs_fvp.c | 111 ++++ .../bsp/bsp-packs/simple_platform/timer_fvp.c | 56 ++ .../bsp/bsp-packs/simple_platform/uart_pl011.c | 224 +++++++ .../platforms/bare-metal/bsp/cmsis-device/cmsis.c | 122 ++++ .../bare-metal/bsp/cmsis-device/include/cmsis.h | 31 + .../bare-metal/bsp/cmsis-device/include/irqs.h | 54 ++ .../platforms/bare-metal/bsp/cmsis-device/irqs.c | 261 ++++++++ .../hal/platforms/bare-metal/bsp/include/bsp.h | 38 ++ .../bare-metal/bsp/mem_layout/mps3-sse-200.sct | 102 +++ .../bare-metal/bsp/mem_layout/mps3-sse-300.sct | 118 ++++ .../bare-metal/bsp/mem_layout/simple_platform.sct | 102 +++ .../bare-metal/data_acquisition/data_acq.c | 61 ++ .../bare-metal/data_presentation/data_psn.c | 46 ++ .../data_presentation/lcd/include/lcd_img.h | 90 +++ .../bare-metal/data_presentation/lcd/lcd_img.c | 159 +++++ .../platforms/bare-metal/timer/baremetal_timer.c | 243 +++++++ .../bare-metal/timer/include/baremetal_timer.h | 41 ++ .../bare-metal/utils/include/system_init.h | 43 ++ .../hal/platforms/bare-metal/utils/system_init.c | 118 ++++ .../platforms/native/data_acquisition/data_acq.c | 61 ++ .../platforms/native/data_presentation/data_psn.c | 45 ++ .../native/data_presentation/log/include/log.h | 86 +++ .../platforms/native/data_presentation/log/log.c | 71 +++ .../platforms/native/timer/include/native_timer.h | 31 + .../hal/platforms/native/timer/native_timer.cc | 110 ++++ .../hal/platforms/native/utils/include/dummy_log.h | 64 ++ .../platforms/native/utils/include/system_init.h | 39 ++ .../hal/platforms/native/utils/system_init.c | 32 + source/application/main/Classifier.cc | 191 ++++++ source/application/main/Main.cc | 70 ++ source/application/main/Mfcc.cc | 354 +++++++++++ source/application/main/PlatformMath.cc | 196 ++++++ source/application/main/Profiler.cc | 219 +++++++ source/application/main/UseCaseCommonUtils.cc | 119 ++++ source/application/main/include/AppContext.hpp | 102 +++ source/application/main/include/AudioUtils.hpp | 171 +++++ .../main/include/ClassificationResult.hpp | 41 ++ source/application/main/include/Classifier.hpp | 74 +++ source/application/main/include/DataStructures.hpp | 132 ++++ 
source/application/main/include/Mfcc.hpp | 255 ++++++++ source/application/main/include/PlatformMath.hpp | 151 +++++ source/application/main/include/Profiler.hpp | 110 ++++ .../main/include/UseCaseCommonUtils.hpp | 76 +++ source/application/tensorflow-lite-micro/Model.cc | 332 ++++++++++ .../tensorflow-lite-micro/TensorFlowLiteMicro.cc | 47 ++ .../include/BufAttributes.hpp | 85 +++ .../tensorflow-lite-micro/include/Model.hpp | 142 +++++ .../include/TensorFlowLiteMicro.hpp | 78 +++ source/use_case/ad/include/AdMelSpectrogram.hpp | 97 +++ source/use_case/ad/include/AdModel.hpp | 53 ++ source/use_case/ad/include/AdPostProcessing.hpp | 50 ++ source/use_case/ad/include/MelSpectrogram.hpp | 233 +++++++ source/use_case/ad/include/UseCaseHandler.hpp | 33 + source/use_case/ad/src/AdMelSpectrogram.cc | 90 +++ source/use_case/ad/src/AdModel.cc | 55 ++ source/use_case/ad/src/AdPostProcessing.cc | 116 ++++ source/use_case/ad/src/MainLoop.cc | 114 ++++ source/use_case/ad/src/MelSpectrogram.cc | 311 +++++++++ source/use_case/ad/src/UseCaseHandler.cc | 422 ++++++++++++ source/use_case/ad/usecase.cmake | 111 ++++ source/use_case/asr/include/AsrClassifier.hpp | 62 ++ source/use_case/asr/include/AsrResult.hpp | 63 ++ source/use_case/asr/include/OutputDecode.hpp | 40 ++ source/use_case/asr/include/UseCaseHandler.hpp | 37 ++ source/use_case/asr/include/Wav2LetterMfcc.hpp | 109 ++++ source/use_case/asr/include/Wav2LetterModel.hpp | 61 ++ .../use_case/asr/include/Wav2LetterPostprocess.hpp | 109 ++++ .../use_case/asr/include/Wav2LetterPreprocess.hpp | 203 ++++++ source/use_case/asr/src/AsrClassifier.cc | 130 ++++ source/use_case/asr/src/MainLoop.cc | 230 +++++++ source/use_case/asr/src/OutputDecode.cc | 47 ++ source/use_case/asr/src/UseCaseHandler.cc | 288 +++++++++ source/use_case/asr/src/Wav2LetterMfcc.cc | 137 ++++ source/use_case/asr/src/Wav2LetterModel.cc | 56 ++ source/use_case/asr/src/Wav2LetterPostprocess.cc | 172 +++++ source/use_case/asr/src/Wav2LetterPreprocess.cc | 228 +++++++ source/use_case/asr/usecase.cmake | 164 +++++ .../use_case/img_class/include/MobileNetModel.hpp | 55 ++ .../use_case/img_class/include/UseCaseHandler.hpp | 37 ++ source/use_case/img_class/src/MainLoop.cc | 109 ++++ source/use_case/img_class/src/MobileNetModel.cc | 57 ++ source/use_case/img_class/src/UseCaseHandler.cc | 269 ++++++++ source/use_case/img_class/usecase.cmake | 125 ++++ .../inference_runner/include/TestModel.hpp | 47 ++ .../inference_runner/include/UseCaseHandler.hpp | 35 + source/use_case/inference_runner/src/MainLoop.cc | 51 ++ source/use_case/inference_runner/src/TestModel.cc | 36 ++ .../inference_runner/src/UseCaseHandler.cc | 88 +++ source/use_case/inference_runner/usecase.cmake | 57 ++ source/use_case/kws/include/DsCnnMfcc.hpp | 50 ++ source/use_case/kws/include/DsCnnModel.hpp | 59 ++ source/use_case/kws/include/KwsResult.hpp | 63 ++ source/use_case/kws/include/UseCaseHandler.hpp | 37 ++ source/use_case/kws/src/DsCnnModel.cc | 58 ++ source/use_case/kws/src/MainLoop.cc | 112 ++++ source/use_case/kws/src/UseCaseHandler.cc | 452 +++++++++++++ source/use_case/kws/usecase.cmake | 159 +++++ source/use_case/kws_asr/include/AsrClassifier.hpp | 64 ++ source/use_case/kws_asr/include/AsrResult.hpp | 63 ++ source/use_case/kws_asr/include/DsCnnMfcc.hpp | 51 ++ source/use_case/kws_asr/include/DsCnnModel.hpp | 67 ++ source/use_case/kws_asr/include/KwsResult.hpp | 63 ++ source/use_case/kws_asr/include/OutputDecode.hpp | 40 ++ source/use_case/kws_asr/include/UseCaseHandler.hpp | 37 ++ 
source/use_case/kws_asr/include/Wav2LetterMfcc.hpp | 112 ++++ .../use_case/kws_asr/include/Wav2LetterModel.hpp | 67 ++ .../kws_asr/include/Wav2LetterPostprocess.hpp | 101 +++ .../kws_asr/include/Wav2LetterPreprocess.hpp | 205 ++++++ source/use_case/kws_asr/src/AsrClassifier.cc | 131 ++++ source/use_case/kws_asr/src/DsCnnModel.cc | 67 ++ source/use_case/kws_asr/src/MainLoop.cc | 233 +++++++ source/use_case/kws_asr/src/OutputDecode.cc | 47 ++ source/use_case/kws_asr/src/UseCaseHandler.cc | 707 +++++++++++++++++++++ source/use_case/kws_asr/src/Wav2LetterMfcc.cc | 137 ++++ source/use_case/kws_asr/src/Wav2LetterModel.cc | 62 ++ .../use_case/kws_asr/src/Wav2LetterPostprocess.cc | 155 +++++ .../use_case/kws_asr/src/Wav2LetterPreprocess.cc | 228 +++++++ source/use_case/kws_asr/usecase.cmake | 259 ++++++++ 138 files changed, 17188 insertions(+) create mode 100644 source/application/hal/hal.c create mode 100644 source/application/hal/include/data_acq.h create mode 100644 source/application/hal/include/data_psn.h create mode 100644 source/application/hal/include/hal.h create mode 100644 source/application/hal/include/hal_config.h create mode 100644 source/application/hal/include/timer.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-core/include/bsp_core_log.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-core/include/uart_stdout.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-core/retarget.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/device_mps3.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/glcd_mps3.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/device_mps3.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/font_9x15_h.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/glcd_mps3.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/smm_mps3.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/timer_mps3.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/timer_mps3.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/uart_stdout.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/stubs_fvp.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/timer_fvp.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/stubs_fvp.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/timer_fvp.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/uart_pl011.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/cmsis-device/cmsis.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/cmsis.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/irqs.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/cmsis-device/irqs.c create mode 100644 source/application/hal/platforms/bare-metal/bsp/include/bsp.h create mode 100644 source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-200.sct create mode 100644 source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-300.sct 
create mode 100644 source/application/hal/platforms/bare-metal/bsp/mem_layout/simple_platform.sct create mode 100644 source/application/hal/platforms/bare-metal/data_acquisition/data_acq.c create mode 100644 source/application/hal/platforms/bare-metal/data_presentation/data_psn.c create mode 100644 source/application/hal/platforms/bare-metal/data_presentation/lcd/include/lcd_img.h create mode 100644 source/application/hal/platforms/bare-metal/data_presentation/lcd/lcd_img.c create mode 100644 source/application/hal/platforms/bare-metal/timer/baremetal_timer.c create mode 100644 source/application/hal/platforms/bare-metal/timer/include/baremetal_timer.h create mode 100644 source/application/hal/platforms/bare-metal/utils/include/system_init.h create mode 100644 source/application/hal/platforms/bare-metal/utils/system_init.c create mode 100644 source/application/hal/platforms/native/data_acquisition/data_acq.c create mode 100644 source/application/hal/platforms/native/data_presentation/data_psn.c create mode 100644 source/application/hal/platforms/native/data_presentation/log/include/log.h create mode 100644 source/application/hal/platforms/native/data_presentation/log/log.c create mode 100644 source/application/hal/platforms/native/timer/include/native_timer.h create mode 100644 source/application/hal/platforms/native/timer/native_timer.cc create mode 100644 source/application/hal/platforms/native/utils/include/dummy_log.h create mode 100644 source/application/hal/platforms/native/utils/include/system_init.h create mode 100644 source/application/hal/platforms/native/utils/system_init.c create mode 100644 source/application/main/Classifier.cc create mode 100644 source/application/main/Main.cc create mode 100644 source/application/main/Mfcc.cc create mode 100644 source/application/main/PlatformMath.cc create mode 100644 source/application/main/Profiler.cc create mode 100644 source/application/main/UseCaseCommonUtils.cc create mode 100644 source/application/main/include/AppContext.hpp create mode 100644 source/application/main/include/AudioUtils.hpp create mode 100644 source/application/main/include/ClassificationResult.hpp create mode 100644 source/application/main/include/Classifier.hpp create mode 100644 source/application/main/include/DataStructures.hpp create mode 100644 source/application/main/include/Mfcc.hpp create mode 100644 source/application/main/include/PlatformMath.hpp create mode 100644 source/application/main/include/Profiler.hpp create mode 100644 source/application/main/include/UseCaseCommonUtils.hpp create mode 100644 source/application/tensorflow-lite-micro/Model.cc create mode 100644 source/application/tensorflow-lite-micro/TensorFlowLiteMicro.cc create mode 100644 source/application/tensorflow-lite-micro/include/BufAttributes.hpp create mode 100644 source/application/tensorflow-lite-micro/include/Model.hpp create mode 100644 source/application/tensorflow-lite-micro/include/TensorFlowLiteMicro.hpp create mode 100644 source/use_case/ad/include/AdMelSpectrogram.hpp create mode 100644 source/use_case/ad/include/AdModel.hpp create mode 100644 source/use_case/ad/include/AdPostProcessing.hpp create mode 100644 source/use_case/ad/include/MelSpectrogram.hpp create mode 100644 source/use_case/ad/include/UseCaseHandler.hpp create mode 100644 source/use_case/ad/src/AdMelSpectrogram.cc create mode 100644 source/use_case/ad/src/AdModel.cc create mode 100644 source/use_case/ad/src/AdPostProcessing.cc create mode 100644 source/use_case/ad/src/MainLoop.cc create mode 100644 
source/use_case/ad/src/MelSpectrogram.cc create mode 100644 source/use_case/ad/src/UseCaseHandler.cc create mode 100644 source/use_case/ad/usecase.cmake create mode 100644 source/use_case/asr/include/AsrClassifier.hpp create mode 100644 source/use_case/asr/include/AsrResult.hpp create mode 100644 source/use_case/asr/include/OutputDecode.hpp create mode 100644 source/use_case/asr/include/UseCaseHandler.hpp create mode 100644 source/use_case/asr/include/Wav2LetterMfcc.hpp create mode 100644 source/use_case/asr/include/Wav2LetterModel.hpp create mode 100644 source/use_case/asr/include/Wav2LetterPostprocess.hpp create mode 100644 source/use_case/asr/include/Wav2LetterPreprocess.hpp create mode 100644 source/use_case/asr/src/AsrClassifier.cc create mode 100644 source/use_case/asr/src/MainLoop.cc create mode 100644 source/use_case/asr/src/OutputDecode.cc create mode 100644 source/use_case/asr/src/UseCaseHandler.cc create mode 100644 source/use_case/asr/src/Wav2LetterMfcc.cc create mode 100644 source/use_case/asr/src/Wav2LetterModel.cc create mode 100644 source/use_case/asr/src/Wav2LetterPostprocess.cc create mode 100644 source/use_case/asr/src/Wav2LetterPreprocess.cc create mode 100644 source/use_case/asr/usecase.cmake create mode 100644 source/use_case/img_class/include/MobileNetModel.hpp create mode 100644 source/use_case/img_class/include/UseCaseHandler.hpp create mode 100644 source/use_case/img_class/src/MainLoop.cc create mode 100644 source/use_case/img_class/src/MobileNetModel.cc create mode 100644 source/use_case/img_class/src/UseCaseHandler.cc create mode 100644 source/use_case/img_class/usecase.cmake create mode 100644 source/use_case/inference_runner/include/TestModel.hpp create mode 100644 source/use_case/inference_runner/include/UseCaseHandler.hpp create mode 100644 source/use_case/inference_runner/src/MainLoop.cc create mode 100644 source/use_case/inference_runner/src/TestModel.cc create mode 100644 source/use_case/inference_runner/src/UseCaseHandler.cc create mode 100644 source/use_case/inference_runner/usecase.cmake create mode 100644 source/use_case/kws/include/DsCnnMfcc.hpp create mode 100644 source/use_case/kws/include/DsCnnModel.hpp create mode 100644 source/use_case/kws/include/KwsResult.hpp create mode 100644 source/use_case/kws/include/UseCaseHandler.hpp create mode 100644 source/use_case/kws/src/DsCnnModel.cc create mode 100644 source/use_case/kws/src/MainLoop.cc create mode 100644 source/use_case/kws/src/UseCaseHandler.cc create mode 100644 source/use_case/kws/usecase.cmake create mode 100644 source/use_case/kws_asr/include/AsrClassifier.hpp create mode 100644 source/use_case/kws_asr/include/AsrResult.hpp create mode 100644 source/use_case/kws_asr/include/DsCnnMfcc.hpp create mode 100644 source/use_case/kws_asr/include/DsCnnModel.hpp create mode 100644 source/use_case/kws_asr/include/KwsResult.hpp create mode 100644 source/use_case/kws_asr/include/OutputDecode.hpp create mode 100644 source/use_case/kws_asr/include/UseCaseHandler.hpp create mode 100644 source/use_case/kws_asr/include/Wav2LetterMfcc.hpp create mode 100644 source/use_case/kws_asr/include/Wav2LetterModel.hpp create mode 100644 source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp create mode 100644 source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp create mode 100644 source/use_case/kws_asr/src/AsrClassifier.cc create mode 100644 source/use_case/kws_asr/src/DsCnnModel.cc create mode 100644 source/use_case/kws_asr/src/MainLoop.cc create mode 100644 source/use_case/kws_asr/src/OutputDecode.cc create 
mode 100644 source/use_case/kws_asr/src/UseCaseHandler.cc create mode 100644 source/use_case/kws_asr/src/Wav2LetterMfcc.cc create mode 100644 source/use_case/kws_asr/src/Wav2LetterModel.cc create mode 100644 source/use_case/kws_asr/src/Wav2LetterPostprocess.cc create mode 100644 source/use_case/kws_asr/src/Wav2LetterPreprocess.cc create mode 100644 source/use_case/kws_asr/usecase.cmake (limited to 'source') diff --git a/source/application/hal/hal.c b/source/application/hal/hal.c new file mode 100644 index 0000000..dbf94ba --- /dev/null +++ b/source/application/hal/hal.c @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* API */ + +#include "hal_config.h" /* HAL configuration */ +#include "system_init.h" + +#include +#include + +#if defined(ARM_NPU) + +#include "ethosu_driver.h" /* Arm Ethos-U55 driver header */ +#include "timing_adapter.h" /* Arm Ethos-U55 timing adapter driver header */ +#include "timing_adapter_settings.h" /* Arm Ethos-U55 timing adapter settings */ + +/** + * @brief Initialises the Arm Ethos-U55 NPU + * @return 0 if successful, error code otherwise + **/ +static int _arm_npu_init(void); + +#endif /* ARM_NPU */ + +int hal_init(hal_platform* platform, data_acq_module* data_acq, + data_psn_module* data_psn, platform_timer* timer) +{ + assert(platform && data_acq && data_psn); + + platform->data_acq = data_acq; + platform->data_psn = data_psn; + platform->timer = timer; + platform->platform_init = system_init; + platform->platform_release = system_release; + system_name(platform->plat_name, sizeof(platform->plat_name)); + + return 0; +} + +/** + * @brief Local helper function to clean the slate for current platform. 
+ **/ +static void _hal_platform_clear(hal_platform* platform) +{ + assert(platform); + platform->inited = 0; +} + +int hal_platform_init(hal_platform* platform) +{ + int state; + assert(platform && platform->platform_init); + _hal_platform_clear(platform); + + /* Initialise platform */ + if (0 != (state = platform->platform_init())) { + printf_err("failed to initialise platform %s\n", platform->plat_name); + return state; + } + + /* Initialise the data acquisition module */ + if (0 != (state = data_acq_channel_init(platform->data_acq))) { + if (!platform->data_acq->inited) { + printf_err("failed to initialise data acq module: %s\n", + platform->data_acq->system_name); + } + hal_platform_release(platform); + return state; + } + + /* Initialise the presentation module */ + if (0 != (state = data_psn_system_init(platform->data_psn))) { + printf_err("failed to initialise data psn module: %s\n", + platform->data_psn->system_name); + data_acq_channel_release(platform->data_acq); + hal_platform_release(platform); + return state; + } + +#if defined(ARM_NPU) + + /* If Arm Ethos-U55 NPU is to be used, we initialise it here */ + if (0 != (state = _arm_npu_init())) { + return state; + } + +#endif /* ARM_NPU */ + + /* followed by the timer module */ + init_timer(platform->timer); + + info("%s platform initialised\n", platform->plat_name); + debug("using %s module for data acquisition\n", + platform->data_acq->system_name); + debug("using %s module for data presentation\n", + platform->data_psn->system_name); + + platform->inited = !state; + + return state; +} + +void hal_platform_release(hal_platform *platform) +{ + assert(platform && platform->platform_release); + data_acq_channel_release(platform->data_acq); + data_psn_system_release(platform->data_psn); + + _hal_platform_clear(platform); + info("releasing platform %s\n", platform->plat_name); + platform->platform_release(); +} + +#if defined(ARM_NPU) +/** + * @brief Defines the Ethos-U interrupt handler: just a wrapper around the default + * implementation. + **/ +static void _arm_npu_irq_handler(void) +{ + /* Call the default interrupt handler from the NPU driver */ + ethosu_irq_handler(); +} + +/** + * @brief Initialises the NPU IRQ + **/ +static void _arm_npu_irq_init(void) +{ + const IRQn_Type ethosu_irqnum = (IRQn_Type)EthosU_IRQn; + + /* Register the EthosU IRQ handler in our vector table. 
+ * Note, this handler comes from the EthosU driver */ + NVIC_SetVector(ethosu_irqnum, (uint32_t)_arm_npu_irq_handler); + + /* Enable the IRQ */ + NVIC_EnableIRQ(ethosu_irqnum); + + debug("EthosU IRQ#: %u, Handler: 0x%p\n", + ethosu_irqnum, _arm_npu_irq_handler); +} + +static int _arm_npu_timing_adapter_init(void) +{ +#if defined (TA0_BASE) + struct timing_adapter ta_0; + struct timing_adapter_settings ta_0_settings = { + .maxr = TA0_MAXR, + .maxw = TA0_MAXW, + .maxrw = TA0_MAXRW, + .rlatency = TA0_RLATENCY, + .wlatency = TA0_WLATENCY, + .pulse_on = TA0_PULSE_ON, + .pulse_off = TA0_PULSE_OFF, + .bwcap = TA0_BWCAP, + .perfctrl = TA0_PERFCTRL, + .perfcnt = TA0_PERFCNT, + .mode = TA0_MODE, + .maxpending = 0, /* This is a read-only parameter */ + .histbin = TA0_HISTBIN, + .histcnt = TA0_HISTCNT + }; + + if (0 != ta_init(&ta_0, TA0_BASE)) { + printf_err("TA0 initialisation failed\n"); + return 1; + } + + ta_set_all(&ta_0, &ta_0_settings); +#endif /* defined (TA0_BASE) */ + +#if defined (TA1_BASE) + struct timing_adapter ta_1; + struct timing_adapter_settings ta_1_settings = { + .maxr = TA1_MAXR, + .maxw = TA1_MAXW, + .maxrw = TA1_MAXRW, + .rlatency = TA1_RLATENCY, + .wlatency = TA1_WLATENCY, + .pulse_on = TA1_PULSE_ON, + .pulse_off = TA1_PULSE_OFF, + .bwcap = TA1_BWCAP, + .perfctrl = TA1_PERFCTRL, + .perfcnt = TA1_PERFCNT, + .mode = TA1_MODE, + .maxpending = 0, /* This is a read-only parameter */ + .histbin = TA1_HISTBIN, + .histcnt = TA1_HISTCNT + }; + + if (0 != ta_init(&ta_1, TA1_BASE)) { + printf_err("TA1 initialisation failed\n"); + return 1; + } + + ta_set_all(&ta_1, &ta_1_settings); +#endif /* defined (TA1_BASE) */ + + return 0; +} + +static int _arm_npu_init(void) +{ + int err = 0; + + /* If the platform has timing adapter blocks along with Ethos-U55 core + * block, initialise them here. */ + if (0 != (err = _arm_npu_timing_adapter_init())) { + return err; + } + + /* Initialise the IRQ */ + _arm_npu_irq_init(); + + /* Initialise Ethos-U55 device */ + const void * ethosu_base_address = (void *)(SEC_ETHOS_U55_BASE); + + if (0 != (err = ethosu_init_v3( + ethosu_base_address, /* Ethos-U55's base address. */ + NULL, /* Pointer to fast mem area - NULL for U55. */ + 0, /* Fast mem region size. */ + 1, /* Security enable. */ + 1))) { /* Privilege enable. */ + printf_err("failed to initalise Ethos-U55 device\n"); + return err; + } + + info("Ethos-U55 device initialised\n"); + + /* Get Ethos-U55 version */ + struct ethosu_version version; + if (0 != (err = ethosu_get_version(&version))) { + printf_err("failed to fetch Ethos-U55 version info\n"); + return err; + } + + info("Ethos-U55 version info:\n"); + info("\tArch: v%u.%u.%u\n", version.id.arch_major_rev, + version.id.arch_minor_rev, + version.id.arch_patch_rev); + info("\tDriver: v%u.%u.%u\n", version.id.driver_major_rev, + version.id.driver_minor_rev, + version.id.driver_patch_rev); + info("\tMACs/cc: %u\n", (1 << version.cfg.macs_per_cc)); + info("\tCmd stream: v%u\n", version.cfg.cmd_stream_version); + info("\tSHRAM size: %u\n", version.cfg.shram_size); + + return 0; +} +#endif /* ARM_NPU */ diff --git a/source/application/hal/include/data_acq.h b/source/application/hal/include/data_acq.h new file mode 100644 index 0000000..965fbe5 --- /dev/null +++ b/source/application/hal/include/data_acq.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATA_ACQ_H
+#define DATA_ACQ_H
+
+/**
+ * This file is the top level abstraction for the data acquisition module.
+ **/
+#include
+
+/* Structure to encompass the data acquisition module and its methods. */
+typedef struct data_acquisition_module {
+    int inited;                 /**< initialised or not. */
+    char system_name[8];        /**< name(s) of the channel in use. */
+    int (* system_init)(void);  /**< channel initialisation function. */
+
+    /* Function to go and check if there are any events that require handling. */
+    int (* get_input)(char *user_input, int size);
+} data_acq_module;
+
+/**
+ * @brief           Initialises the data acquisition channel: sets the
+ *                  required channel up for use.
+ * @param[in,out]   module  Pointer to a pre-allocated data
+ *                          acquisition structure object.
+ * @return          0 if successful, error code otherwise.
+ **/
+int data_acq_channel_init(data_acq_module *module);
+
+/**
+ * @brief           Releases the data acquisition channel.
+ * @param[in,out]   module  Pointer to a pre-allocated data
+ *                          acquisition structure object.
+ * @return          0 if successful, error code otherwise.
+ **/
+int data_acq_channel_release(data_acq_module *module);
+
+#endif /* DATA_ACQ_H */
diff --git a/source/application/hal/include/data_psn.h b/source/application/hal/include/data_psn.h
new file mode 100644
index 0000000..8c14c77
--- /dev/null
+++ b/source/application/hal/include/data_psn.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
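[Editor's note] Before moving on to the presentation interface, a minimal sketch of how a concrete channel can slot into the data_acq_module contract defined above. The stub names are hypothetical and exist only to illustrate the interface; the real bare-metal (UART-based) and native implementations live in the data_acquisition/data_acq.c files added by this patch.

#include "data_acq.h"

#include <stdio.h>
#include <string.h>

/* Hypothetical no-op channel, purely to illustrate the contract. */
static int stub_channel_init(void)
{
    return 0; /* nothing to set up for this stub */
}

static int stub_get_input(char *user_input, int size)
{
    /* A real channel would block here waiting for user input (e.g. over UART). */
    snprintf(user_input, size, "exit");
    return 0;
}

int data_acq_channel_init(data_acq_module *module)
{
    module->system_init = stub_channel_init;
    module->get_input   = stub_get_input;
    strncpy(module->system_name, "stub", sizeof(module->system_name));
    module->inited      = (0 == module->system_init());
    return module->inited ? 0 : 1;
}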
+ */ +#ifndef DATA_PSN_H +#define DATA_PSN_H + +/** + * This file is the top level abstraction for the data presentation module + **/ +#include +#include +#include + +/* Structure to encompass the data presentation module and it's methods */ +typedef struct data_presentation_module { + int inited; /**< initialised or not */ + char system_name[8]; /**< name of the system in use */ + int (* system_init)(void); /**< pointer to init function */ + + /** Pointer to the image presentation function */ + int (* present_data_image)(uint8_t *data, const uint32_t width, + const uint32_t height, const uint32_t channels, + const uint32_t pos_x, const uint32_t pos_y, + const uint32_t downsample_factor); + + /* Pointer to text presentation function */ + int (* present_data_text)(const char *str, const size_t str_sz, + const uint32_t pos_x, const uint32_t pos_y, + const bool allow_multiple_lines); + + /* Pointer to box presentation function */ + int (* present_box)(const uint32_t pos_x, const uint32_t pos_y, + const uint32_t width, const uint32_t height, const uint16_t color); + + /* Pointer to clear presentation function */ + int (* clear)(const uint16_t color); + + /* Pointer to set text color presentation function */ + int (* set_text_color)(const uint16_t color); +} data_psn_module; + + +/** + * @brief Initialises the data presentation system. + * @param[in,out] module Pointer to a pre-allocated data + * presentation structure object. + * @return 0 if successful, error code otherwise. + **/ +int data_psn_system_init(data_psn_module *module); + +/** + * @brief Releases the data presentation system. + * @param[in,out] module Pointer to a pre-allocated data + * presentation structure object. + * @return 0 if successful, error code otherwise. + **/ +int data_psn_system_release(data_psn_module *module); + +#endif /* DATA_PSN_H */ diff --git a/source/application/hal/include/hal.h b/source/application/hal/include/hal.h new file mode 100644 index 0000000..26ba1e3 --- /dev/null +++ b/source/application/hal/include/hal.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef PLATFORM_HAL_H +#define PLATFORM_HAL_H + +/** + * This file should present a C API for the main application logic to use + * and be indifferent to the lower level platform. In addition to this it + * will also need to be aware of the API exposed by data acquisition and + * data presentation modules. 
+ */ +#include "hal_config.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#include "data_acq.h" /* Data acquisition abstraction */ +#include "data_psn.h" /* Data presentation abstraction */ +#include "timer.h" /* Timer/profiler API */ + +/* Structure to define a platform context to be used by the application */ +typedef struct hal_platform_context { + int inited; /**< initialised */ + char plat_name[16]; /**< name of this platform */ + data_acq_module * data_acq; /**< data acquisition module pointer */ + data_psn_module * data_psn; /**< data presentation module pointer */ + platform_timer * timer; /**< timer */ + int (* platform_init)(); /**< pointer to platform initialisation function */ + void (* platform_release)(); /**< pointer to platform release function */ +} hal_platform; + +/** + * @brief Initialise the HAL structure based on compile time config. This + * should be called before any other function in this API. + * @param[in,out] platform Pointer to a pre-allocated platform struct. + * @param[in,out] data_acq Pointer to a pre-allocated data acquisition module. + * @param[in,out] data_psn Pointer to a pre-allocated data presentation module. + * @param[in,out] timer Pointer to a pre-allocated timer module. + * @return 0 if successful, error code otherwise. + **/ +int hal_init(hal_platform *platform, data_acq_module *data_acq, + data_psn_module *data_psn, platform_timer *timer); + + +/** + * @brief Initialise the HAL platform. This will go and initialise all the + * modules on the platform the application requires to run. + * @param[in] platform Pointer to a pre-allocated and initialised + * platform structure. + * @return 0 if successful, error code otherwise. + **/ +int hal_platform_init(hal_platform *platform); + + +/** + * @brief Release the HAL platform. This should release resources acquired. + * @param[in] platform pointer to a pre-allocated and initialised + * platform structure. + **/ +void hal_platform_release(hal_platform *platform); + +#ifdef __cplusplus +} +#endif + +#endif /* PLATFORM_HAL_H */ diff --git a/source/application/hal/include/hal_config.h b/source/application/hal/include/hal_config.h new file mode 100644 index 0000000..55db973 --- /dev/null +++ b/source/application/hal/include/hal_config.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
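[Editor's note] Read together, the hal_init/hal_platform_init/hal_platform_release trio above defines the whole platform lifecycle. A minimal usage sketch of that flow, with error handling trimmed and the use-case loop elided:

#include "hal.h"

int main(void)
{
    hal_platform platform;
    data_acq_module data_acq;
    data_psn_module data_psn;
    platform_timer timer;

    /* Wire the compile-time selected modules into the platform struct. */
    hal_init(&platform, &data_acq, &data_psn, &timer);

    /* Bring up the platform, data channels, timer and (if built in) the NPU. */
    if (0 != hal_platform_init(&platform)) {
        return 1;
    }

    /* ... application / use-case loop runs here ... */

    hal_platform_release(&platform);
    return 0;
}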
+ */ +#ifndef HAL_CONFIG_H +#define HAL_CONFIG_H + +/* This header provides some basic configuration for HAL */ + +/* Platform definitions for the systems we expect to support */ +#define PLATFORM_CORTEX_M_BAREMETAL 1U +#define PLATFORM_UNKNOWN_LINUX_OS 3U + +/* This should come from compile time definition */ +#ifndef PLATFORM_HAL + #define PLATFORM_HAL PLATFORM_UNKNOWN_LINUX_OS /* Default platform */ +#endif /* PLATFORM_HAL */ + +#if ((PLATFORM_HAL) == PLATFORM_CORTEX_M_BAREMETAL) + #include "bsp.h" +#elif ((PLATFORM_HAL) == PLATFORM_UNKNOWN_LINUX_OS) + #include "dummy_log.h" +#else + #error "Invalid platform!" +#endif /* PLATFORM_HAL==PLATFORM_CORTEX_M_BAREMETAL */ + +#if !defined (DESIGN_NAME) + #define DESIGN_NAME ("N/A") +#endif /* !defined (DESIGN_NAME) */ + +#endif /* HAL_CONFIG_H */ diff --git a/source/application/hal/include/timer.h b/source/application/hal/include/timer.h new file mode 100644 index 0000000..2955b7f --- /dev/null +++ b/source/application/hal/include/timer.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef HAL_TIMER_H +#define HAL_TIMER_H + +#include "hal_config.h" + +#if ((PLATFORM_HAL) == PLATFORM_CORTEX_M_BAREMETAL) +#include "baremetal_timer.h" +#elif ((PLATFORM_HAL) == PLATFORM_UNKNOWN_LINUX_OS) +#include "native_timer.h" +#else +#error "Platform does not support a timer API" +#endif /* PLATFORM_HAL */ + +/** Struct for describing the capabilities available for + * the timer provided by HAL */ +typedef struct _platform_timer_capability { + uint32_t npu_cycles: 1; + uint32_t cpu_cycles: 1; + uint32_t duration_ms: 1; + uint32_t duration_us: 1; +} timer_capability; + +/* Structure to hold a platform specific timer implementation */ +typedef struct _platform_timer { + int inited; /**< initialised or not */ + timer_capability cap; /**< capability of this timer */ + + /* reset the timer */ + void (* reset)(void); + + /* Gets the current time counter. */ + time_counter (* get_time_counter)(void); + + /* Gets the duration in milliseconds. */ + time_t (* get_duration_ms)(time_counter *start, time_counter *end); + + /* Gets duration in microseconds. */ + time_t (* get_duration_us)(time_counter *start, time_counter *end); + + /* Gets difference in CPU cycle counts. */ + uint32_t (* get_cpu_cycle_diff)(time_counter *start, time_counter *end); + + /* Gets the difference in terms of total NPU cycle counts. */ + uint64_t (* get_npu_total_cycle_diff)(time_counter *start, time_counter *end); + + /* Gets the difference in terms of active NPU cycle counts. */ + uint64_t (* get_npu_active_cycle_diff)(time_counter *start, time_counter *end); + + /* Wraps get_time_counter function with additional profiling + * initialisation, if required. */ + time_counter (* start_profiling)(void); + + /* Wraps get_time_counter function along with additional instructions when + * profiling ends, if required. 
*/ + time_counter (* stop_profiling)(void); + +} platform_timer; + +/** + * @brief Initialise the timer available for the platform. + **/ +void init_timer(platform_timer *timer); + +#endif /* HAL_TIMER_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-core/include/bsp_core_log.h b/source/application/hal/platforms/bare-metal/bsp/bsp-core/include/bsp_core_log.h new file mode 100644 index 0000000..f049209 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-core/include/bsp_core_log.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BSP_CORE_LOG_H +#define BSP_CORE_LOG_H + +#include "uart_stdout.h" /* UART for logging */ + +#include + +#define LOG_LEVEL_TRACE 0 +#define LOG_LEVEL_DEBUG 1 +#define LOG_LEVEL_INFO 2 +#define LOG_LEVEL_WARN 3 +#define LOG_LEVEL_ERROR 4 + +#ifndef LOG_LEVEL +#define LOG_LEVEL LOG_LEVEL_INFO +#endif /*LOG_LEVEL*/ + +#if (LOG_LEVEL == LOG_LEVEL_TRACE) + #define trace(...) printf("[TRACE] "); printf(__VA_ARGS__) +#else + #define trace(...) +#endif /* LOG_LEVEL == LOG_LEVEL_TRACE */ + +#if (LOG_LEVEL <= LOG_LEVEL_DEBUG) + #define debug(...) printf("[DEBUG] "); printf(__VA_ARGS__) +#else + #define debug(...) +#endif /* LOG_LEVEL > LOG_LEVEL_TRACE */ + +#if (LOG_LEVEL <= LOG_LEVEL_INFO) + #define info(...) printf("[INFO] "); printf(__VA_ARGS__) +#else + #define info(...) +#endif /* LOG_LEVEL > LOG_LEVEL_DEBUG */ + +#if (LOG_LEVEL <= LOG_LEVEL_WARN) + #define warn(...) printf("[WARN] "); printf(__VA_ARGS__) +#else + #define warn(...) +#endif /* LOG_LEVEL > LOG_LEVEL_INFO */ + +#if (LOG_LEVEL <= LOG_LEVEL_ERROR) + #define printf_err(...) printf("[ERROR] "); printf(__VA_ARGS__) +#else + #define printf_err(...) +#endif /* LOG_LEVEL > LOG_LEVEL_INFO */ + +#define UNUSED(x) ((void)(x)) + +#endif /* BSP_CORE_LOG_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-core/include/uart_stdout.h b/source/application/hal/platforms/bare-metal/bsp/bsp-core/include/uart_stdout.h new file mode 100644 index 0000000..9c5fbcf --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-core/include/uart_stdout.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
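[Editor's note] A short sketch of how a caller is expected to drive the platform_timer interface defined above. The time_counter type is platform-specific and comes from baremetal_timer.h or native_timer.h via timer.h; init_timer() is normally invoked for you by hal_platform_init(), and the workload in the middle stands in for whatever is being measured:

#include "timer.h"

#include <stdio.h>

static void profile_once(platform_timer *timer)
{
    init_timer(timer); /* usually already done by hal_platform_init() */

    time_counter start = timer->start_profiling();

    /* ... run the workload being measured (e.g. one inference) ... */

    time_counter end = timer->stop_profiling();

    if (timer->cap.duration_ms) {
        printf("duration: %ld ms\n",
               (long)timer->get_duration_ms(&start, &end));
    }
    if (timer->cap.cpu_cycles) {
        printf("CPU cycles: %u\n",
               (unsigned)timer->get_cpu_cycle_diff(&start, &end));
    }
}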
+ */ +#ifndef UART_STDOUT_H +#define UART_STDOUT_H + +#include + +/** + * @brief Initialised the UART block. + **/ +extern void UartStdOutInit(void); + +/** + * @brief Transmits a character over UART (blocking call). + * @param[in] my_ch Character to be transmitted. + * @return Character transmitted. + **/ +extern unsigned char UartPutc(unsigned char my_ch); + +/** + * @brief Receives a character from the UART block (blocking call). + * @return Character received. + **/ +extern unsigned char UartGetc(void); + +/** + * @brief Reads characters from the UART block until a line feed or + * carriage return terminates the function. NULL character + * also terminates the function, error is returned. + * @param[out] lp Characters read from the UART block. + * @param[in] len Character to be transmitted. + * @return true if successful, false otherwise. + **/ +extern bool GetLine(char *lp, unsigned int len); + +/** + * @brief Terminates UART simulation. This is useful when a Fixed + * Virtual Platform's session needs to be gracefully terminated. + * @param[in] code Terminating code displayed on the UART before the end of the simulation. + **/ +extern void UartEndSimulation(int code); + +#endif /* UART_STDOUT_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-core/retarget.c b/source/application/hal/platforms/bare-metal/bsp/bsp-core/retarget.c new file mode 100644 index 0000000..cf31a53 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-core/retarget.c @@ -0,0 +1,235 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050) + +#include "uart_stdout.h" +#include "bsp_core_log.h" + +#if defined (MPS3_PLATFORM) +#include "smm_mps3.h" +#endif /* MPS3_PLATFORM */ + +#include +#include +#include +#include +#include + + +/* Standard IO device handles. */ +#define STDIN 0x8001 +#define STDOUT 0x8002 +#define STDERR 0x8003 + +/* Standard IO device name defines. */ +const char __stdin_name[] = "STDIN"; +const char __stdout_name[] = "STDOUT"; +const char __stderr_name[] = "STDERR"; + +int fputc(int ch, FILE *f) +{ + UNUSED(f); + return (UartPutc(ch)); +} + +int fgetc(FILE *f) +{ + UNUSED(f); + return (UartPutc(UartGetc())); +} + +int ferror(FILE *f) +{ + UNUSED(f); + /* Your implementation of ferror */ + return EOF; +} + +void _ttywrch(int ch) +{ + UartPutc(ch); +} + +FILEHANDLE _sys_open(const char *name, int openmode) +{ + UNUSED(openmode); + + /* Register standard Input Output devices. 
*/ + if (strcmp(name, "STDIN") == 0) + { + return (STDIN); + } + if (strcmp(name, "STDOUT") == 0) + { + return (STDOUT); + } + if (strcmp(name, "STDERR") == 0) + { + return (STDERR); + } + return (-1); +} + +int _sys_close(FILEHANDLE fh) +{ + if (fh > 0x8000) + { + return (0); + } + return (-1); +} + +int _sys_write(FILEHANDLE fh, const unsigned char *buf, unsigned int len, int mode) +{ + UNUSED(mode); + if (fh == STDOUT || fh == STDERR ) + { + /* Standard Output device. */ + for (; len; len--) + { + UartPutc(*buf++); + } + return (0); + } + + if (fh > 0x8000) + { + return (-1); + } + return (-1); +} + +int _sys_read(FILEHANDLE fh, unsigned char *buf, unsigned int len, int mode) +{ + UNUSED(mode); + if (fh == STDIN) + { + /* Standard Input device. */ + for (; len; len--) + { + *buf++ = UartGetc(); + } + return (0); + } + + if (fh > 0x8000) + { + return (-1); + } + return (-1); +} + +int _sys_istty(FILEHANDLE fh) +{ + if (fh > 0x8000) + { + return (1); + } + return (0); +} + +int _sys_seek(FILEHANDLE fh, long pos) +{ + UNUSED(pos); + if (fh > 0x8000) + { + return (-1); + } + return (-1); +} + +int _sys_ensure(FILEHANDLE fh) +{ + if (fh > 0x8000) + { + return (-1); + } + return (-1); +} + +long _sys_flen(FILEHANDLE fh) +{ + if (fh > 0x8000) + { + return (0); + } + return (-1); +} + +int _sys_tmpnam(char *name, int sig, unsigned maxlen) +{ + UNUSED(name); + UNUSED(sig); + UNUSED(maxlen); + return (1); +} + +char *_sys_command_string(char *cmd, int len) +{ + UNUSED(len); + return (cmd); +} + +void _sys_exit(int return_code) +{ + UartEndSimulation(return_code); +} + +int system(const char *cmd) +{ + UNUSED(cmd); + return (0); +} + +time_t time(time_t *timer) +{ + time_t current; + +#if defined (MPS3_PLATFORM) + current = MPS3_FPGAIO->COUNTER; +#else /* MPS3_PLATFORM */ + current = 0; /* No RTC implementation available. */ +#endif /* MPS3_PLATFORM */ + + if (timer != NULL) { + *timer = current; + } + + return (current); +} + +#else /* #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050) */ + +/******************************************************************************/ +/* Retarget functions for GNU Tools for ARM Embedded Processors */ +/******************************************************************************/ +#include +#include + +extern unsigned char UartPutc(unsigned char my_ch); + +__attribute__((used)) int _write(int fd, char *ptr, int len) +{ + size_t i; + for (i = 0; i < len; i++) + { + UartPutc(ptr[i]); /* call character output function. */ + } + return len; +} + +#endif /* #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050) */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/device_mps3.c b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/device_mps3.c new file mode 100644 index 0000000..f4f2e6b --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/device_mps3.c @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "device_mps3.h" + +#include "bsp_core_log.h" +#include "smm_mps3.h" + +uint32_t GetMPS3CoreClock(void) +{ + const uint32_t default_clock = 32000000; + static int warned_once = 0; + if (0 != MPS3_SCC->CFG_ACLK) { + return MPS3_SCC->CFG_ACLK; + } + + if (!warned_once) { + warn("MPS3_SCC->CFG_ACLK reads 0. Assuming default clock of %u\n", + default_clock); + warned_once = 1; + } + return default_clock; +} diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/glcd_mps3.c b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/glcd_mps3.c new file mode 100644 index 0000000..530be4f --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/glcd_mps3.c @@ -0,0 +1,460 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "glcd_mps3.h" + +#include "bsp_core_log.h" +#include "font_9x15_h.h" +#include "smm_mps3.h" + +/*-------------- CLCD Controller Internal Register addresses ----------------*/ +#define CHAR_COM ((volatile unsigned int *)(CLCD_CONFIG_BASE + 0x000)) +#define CHAR_DAT ((volatile unsigned int *)(CLCD_CONFIG_BASE + 0x004)) +#define CHAR_RD ((volatile unsigned int *)(CLCD_CONFIG_BASE + 0x008)) +#define CHAR_RAW ((volatile unsigned int *)(CLCD_CONFIG_BASE + 0x00C)) +#define CHAR_MASK ((volatile unsigned int *)(CLCD_CONFIG_BASE + 0x010)) +#define CHAR_STAT ((volatile unsigned int *)(CLCD_CONFIG_BASE + 0x014)) +#define CHAR_MISC ((volatile unsigned int *)(CLCD_CONFIG_BASE + 0x04C)) + +/*--------------- Graphic LCD interface hardware definitions -----------------*/ +/* Pin CS setting to 0 or 1 */ +#define LCD_CS(x) ((x) ? (*CHAR_MISC |= CLCD_CS_Msk) : (*CHAR_MISC &= ~CLCD_CS_Msk)) +#define LCD_RST(x) ((x) ? (*CHAR_MISC |= CLCD_RESET_Msk) : (*CHAR_MISC &= ~CLCD_RESET_Msk)) +#define LCD_BL(x) ((x) ? (*CHAR_MISC |= CLCD_BL_Msk) : (*CHAR_MISC &= ~CLCD_BL_Msk)) + +#define BG_COLOR 0 /* Background colour */ +#define TXT_COLOR 1 /* Text colour */ + +/** +* Text and background colour +*/ +static volatile unsigned short Color[2] = {Black, White}; + +/** + * @brief Delay in while loop cycles. + * @param[in] cnt Number of while cycles to delay. + **/ +static void delay (int cnt) +{ + cnt <<= DELAY_2N; + while (cnt != 0) { + --cnt; + } +} + +/** + * @brief Write a command the LCD controller. + * @param[in] cmd Command to be written. + */ +static __inline void wr_cmd(unsigned char cmd) +{ + LCD_CS(0); + *CHAR_COM = cmd; + LCD_CS(1); +} + +/** + * @brief Start of data writing to the LCD controller. + */ +static __inline void wr_dat_start (void) +{ + LCD_CS(0); +} + +/** + * @brief Stop of data writing to the LCD controller. + */ +static __inline void wr_dat_stop (void) +{ + LCD_CS(1); +} + +/** + * @brief Data writing to the LCD controller. + * @param[in] dat Data to be written. 
+ */ +static __inline void wr_dat_only(unsigned short dat) +{ + *CHAR_DAT = (dat >> 8); /* Write D8..D15 */ + *CHAR_DAT = (dat & 0xFF); /* Write D0..D7 */ +} + +/** + * @brief Write a value to the to LCD register. + * @param[in] reg Register to be written. + * @param[in] val Value to write to the register. + */ +static __inline void wr_reg(unsigned char reg, unsigned short val) +{ + LCD_CS(0); + *CHAR_COM = reg; + wr_dat_only(val); + LCD_CS(1); +} + +/** + * @brief Converts a gray value to RGB565 representation. + * @param[in] src_uchar Pointer to the source pixel. + * @return 16 bit RGB565 value. + */ +static inline uint16_t _GLCD_Gray8_to_RGB565(uint8_t *src_uchar) +{ + uint16_t val_r = (*src_uchar >> 3); + uint16_t val_g = (*src_uchar >> 2); + return ((val_r << 11) | (val_g << 5) | val_r); +} + +/** + * @brief Converts an RGB888 value to RGB565 representation. + * @param[in] src_uchar Pointer to the source pixel for R (assumed to + * be RGB format). + * @return 16 bit RGB565 value. + */ +static inline uint16_t _GLCD_RGB888_to_RGB565(uint8_t *src_uchar) +{ + uint16_t val_r = (*src_uchar >> 3) & 0x1F; + uint16_t val_g = (*(src_uchar+1) >> 2) & 0x3F; + uint16_t val_b = (*(src_uchar+2) >> 3) & 0x1F; + return ((val_r << 11) | (val_g << 5) | val_b); +} + +/* Helper typedef to encapsulate the colour conversion function + * signatures */ +typedef uint16_t (* std_clr_2_lcd_clr_fn)(uint8_t *src_uchar); + +void GLCD_SetWindow(unsigned int x, unsigned int y, unsigned int w, unsigned int h) { + unsigned int xe, ye; + + xe = x+w-1; + ye = y+h-1; + + wr_reg(0x02, x >> 8); /* Column address start MSB */ + wr_reg(0x03, x & 0xFF); /* Column address start LSB */ + wr_reg(0x04, xe >> 8); /* Column address end MSB */ + wr_reg(0x05, xe & 0xFF); /* Column address end LSB */ + + wr_reg(0x06, y >> 8); /* Row address start MSB */ + wr_reg(0x07, y & 0xFF); /* Row address start LSB */ + wr_reg(0x08, ye >> 8); /* Row address end MSB */ + wr_reg(0x09, ye & 0xFF); /* Row address end LSB */ +} + +void GLCD_WindowMax(void) +{ + GLCD_SetWindow (0, 0, GLCD_WIDTH, GLCD_HEIGHT); +} + +void GLCD_SetTextColor(unsigned short color) +{ + Color[TXT_COLOR] = color; +} + +void GLCD_SetBackColor(unsigned short color) +{ + Color[BG_COLOR] = color; +} + +void GLCD_Clear(unsigned short color) +{ + unsigned int i; + + GLCD_WindowMax(); + wr_cmd(0x22); + wr_dat_start(); + + for(i = 0; i < (GLCD_WIDTH*GLCD_HEIGHT); ++i) { + wr_dat_only(color); + } + wr_dat_stop(); +} + + +void GLCD_DrawChar( + unsigned int x, unsigned int y, + unsigned int cw, unsigned int ch, + unsigned char *c) +{ + unsigned int i, j, k, pixs; + + /* Sanity check: out of bounds? */ + if ((x + cw) > GLCD_WIDTH || (y + ch) > GLCD_HEIGHT) { + return; + } + + GLCD_SetWindow(x, y, cw, ch); + + wr_cmd(0x22); + wr_dat_start(); + + k = (cw + 7)/8; + + if (k == 1) { + for (j = 0; j < ch; ++j) { + pixs = *(unsigned char *)c; + c += 1; + + for (i = 0; i < cw; ++i) { + wr_dat_only (Color[(pixs >> i) & 1]); + } + } + } + else if (k == 2) { + for (j = 0; j < ch; ++j) { + pixs = *(unsigned short *)c; + c += 2; + + for (i = 0; i < cw; ++i) { + wr_dat_only (Color[(pixs >> i) & 1]); + } + } + } + wr_dat_stop(); +} + +void GLCD_DisplayChar( + unsigned int ln, unsigned int col, + unsigned char fi, unsigned char c) +{ + c -= 32; + switch (fi) { + case 0: /* Font 9 x 15. 
*/ + GLCD_DrawChar(col * 9, ln * 15, 9, 15, + (unsigned char *)&Font_9x15_h[c * 15]); + break; + } +} + +void GLCD_DisplayString( + unsigned int ln, unsigned int col, + unsigned char fi, char *s) +{ + while (*s) { + GLCD_DisplayChar(ln, col++, fi, *s++); + } +} + + + +void GLCD_ClearLn(unsigned int ln, unsigned char fi) +{ + unsigned char i; + char buf[60]; + + GLCD_WindowMax(); + switch (fi) { + case 0: /* Font 9x15*/ + for (i = 0; i < (GLCD_WIDTH+8)/9; ++i) { + buf[i] = ' '; + } + buf[i+1] = 0; + break; + } + GLCD_DisplayString (ln, 0, fi, buf); +} + +void GLCD_Bitmap(unsigned int x, unsigned int y, + unsigned int w, unsigned int h, + unsigned short *bitmap) +{ + unsigned int i; + unsigned short *bitmap_ptr = bitmap; + + GLCD_SetWindow (x, y, w, h); + + wr_cmd(0x22); + wr_dat_start(); + + for (i = 0; i < (w*h); ++i) { + wr_dat_only (bitmap_ptr[i]); + } + wr_dat_stop(); +} + +void GLCD_Image(void *data, const uint32_t width, + const uint32_t height, const uint32_t channels, + const uint32_t pos_x, const uint32_t pos_y, + const uint32_t downsample_factor) +{ + uint32_t i, j = 0; /* for loops */ + const uint32_t x_incr = channels * downsample_factor; /* stride. */ + const uint32_t y_incr = channels * width * (downsample_factor - 1); /* skip rows. */ + uint8_t* src_unsigned = (uint8_t *)data; /* temporary pointer. */ + std_clr_2_lcd_clr_fn cvt_clr_fn = 0; /* colour conversion function. */ + + /* Based on number of channels, we decide which of the above functions to use. */ + switch (channels) { + case 1: + cvt_clr_fn = _GLCD_Gray8_to_RGB565; + break; + + case 3: + cvt_clr_fn = _GLCD_RGB888_to_RGB565; + break; + + default: + printf_err("number of channels not supported by display\n"); + return; + } + + /* Set the window position expected. Note: this is integer div. */ + GLCD_SetWindow(pos_x, pos_y, + width/downsample_factor, height/downsample_factor); + wr_cmd(0x22); + wr_dat_start(); + + /* Loop over the image. */ + for (j = height; j != 0; j -= downsample_factor) { + for (i = width; i != 0; i -= downsample_factor) { + wr_dat_only(cvt_clr_fn(src_unsigned)); + src_unsigned += x_incr; + } + + /* Skip rows if needed. */ + src_unsigned += y_incr; + } + + wr_dat_stop(); +} + +void GLCD_Box( + unsigned int x, unsigned int y, + unsigned int w, unsigned int h, + unsigned short color) +{ + unsigned int i; + + GLCD_SetWindow (x, y, w, h); + + wr_cmd(0x22); + wr_dat_start(); + for(i = 0; i < (w*h); ++i){ + wr_dat_only (color); + } + wr_dat_stop(); +} + + +void GLCD_Initialize (void) +{ + /* CLCD screen setup (Default CLCD screen interface state) ------------- */ + LCD_CS(1); /* deassert nCS0. */ + LCD_RST(1); /* deassert Reset. */ + LCD_BL(0); /* switch off backlight. */ + + /* Reset CLCD screen --------------------------------------------------- */ + LCD_RST(0); /* assert Reset. */ + delay(1); + LCD_RST(1); /* deassert Reset. */ + delay(10); + + /* Driving ability settings ----------------------------------------------*/ + wr_reg(0xEA, 0x00); /* Power control internal used (1). */ + wr_reg(0xEB, 0x20); /* Power control internal used (2). */ + wr_reg(0xEC, 0x0C); /* Source control internal used (1). */ + wr_reg(0xED, 0xC7); /* Source control internal used (2). */ + wr_reg(0xE8, 0x38); /* Source output period Normal mode. */ + wr_reg(0xE9, 0x10); /* Source output period Idle mode. */ + wr_reg(0xF1, 0x01); /* RGB 18-bit interface ;0x0110. 
*/ + wr_reg(0xF2, 0x10); + + /* Adjust the Gamma Curve ------------------------------------------------*/ + wr_reg(0x40, 0x01); + wr_reg(0x41, 0x00); + wr_reg(0x42, 0x00); + wr_reg(0x43, 0x10); + wr_reg(0x44, 0x0E); + wr_reg(0x45, 0x24); + wr_reg(0x46, 0x04); + wr_reg(0x47, 0x50); + wr_reg(0x48, 0x02); + wr_reg(0x49, 0x13); + wr_reg(0x4A, 0x19); + wr_reg(0x4B, 0x19); + wr_reg(0x4C, 0x16); + + wr_reg(0x50, 0x1B); + wr_reg(0x51, 0x31); + wr_reg(0x52, 0x2F); + wr_reg(0x53, 0x3F); + wr_reg(0x54, 0x3F); + wr_reg(0x55, 0x3E); + wr_reg(0x56, 0x2F); + wr_reg(0x57, 0x7B); + wr_reg(0x58, 0x09); + wr_reg(0x59, 0x06); + wr_reg(0x5A, 0x06); + wr_reg(0x5B, 0x0C); + wr_reg(0x5C, 0x1D); + wr_reg(0x5D, 0xCC); + + /* Power voltage setting -------------------------------------------------*/ + wr_reg(0x1B, 0x1B); + wr_reg(0x1A, 0x01); + wr_reg(0x24, 0x2F); + wr_reg(0x25, 0x57); + wr_reg(0x23, 0x88); + + /* Power on setting ------------------------------------------------------*/ + wr_reg(0x18, 0x36); /* Internal oscillator frequency adj. */ + wr_reg(0x19, 0x01); /* Enable internal oscillator. */ + wr_reg(0x01, 0x00); /* Normal mode, no scroll. */ + wr_reg(0x1F, 0x88); /* Power control 6 - DDVDH Off. */ + delay(20); + wr_reg(0x1F, 0x82); /* Power control 6 - Step-up: 3 x VCI. */ + delay(5); + wr_reg(0x1F, 0x92); /* Power control 6 - Step-up: On. */ + delay(5); + wr_reg(0x1F, 0xD2); /* Power control 6 - VCOML active. */ + delay(5); + + /* Color selection -------------------------------------------------------*/ + wr_reg(0x17, 0x55); /* RGB, System interface: 16 Bit/Pixel. */ + wr_reg(0x00, 0x00); /* Scrolling off, no standby. */ + + /* Interface config ------------------------------------------------------*/ + wr_reg(0x2F, 0x11); /* LCD Drive: 1-line inversion. */ + wr_reg(0x31, 0x00); + wr_reg(0x32, 0x00); /* DPL=0, HSPL=0, VSPL=0, EPL=0. */ + + /* Display on setting ----------------------------------------------------*/ + wr_reg(0x28, 0x38); /* PT(0,0) active, VGL/VGL. */ + delay(20); + wr_reg(0x28, 0x3C); /* Display active, VGL/VGL. */ + +#if (LANDSCAPE == 1) +#if (ROTATE180 == 0) + wr_reg (0x16, 0xA8); +#else /* (ROTATE180 == 0) */ + wr_reg (0x16, 0x68); +#endif /* (ROTATE180 == 0) */ +#else /* (LANDSCAPE == 1) */ +#if (ROTATE180 == 0) + wr_reg (0x16, 0x08); +#else /* (ROTATE180 == 0) */ + wr_reg (0x16, 0xC8); +#endif /* (ROTATE180 == 0) */ +#endif /* (LANDSCAPE == 1) */ + + /* Display scrolling settings --------------------------------------------*/ + wr_reg(0x0E, 0x00); /* TFA MSB */ + wr_reg(0x0F, 0x00); /* TFA LSB */ + wr_reg(0x10, 320 >> 8); /* VSA MSB */ + wr_reg(0x11, 320 & 0xFF); /* VSA LSB */ + wr_reg(0x12, 0x00); /* BFA MSB */ + wr_reg(0x13, 0x00); /* BFA LSB */ + + LCD_BL(1); /* turn on backlight */ +} diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/device_mps3.h b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/device_mps3.h new file mode 100644 index 0000000..f0bab79 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/device_mps3.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DEVICE_MPS3_H +#define DEVICE_MPS3_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "cmsis.h" /* CMSIS device header. */ +#include "smm_mps3.h" /* Memory map for MPS3. */ + +#include + +typedef struct _CMSDK_UART_TypeDef_ +{ + __IO uint32_t DATA; /* Offset: 0x000 (R/W) Data Register. */ + __IO uint32_t STATE; /* Offset: 0x004 (R/W) Status Register. */ + __IO uint32_t CTRL; /* Offset: 0x008 (R/W) Control Register. */ + + union { + __I uint32_t INTSTATUS; /* Offset: 0x00C (R/ ) Interrupt Status Register. */ + __O uint32_t INTCLEAR; /* Offset: 0x00C ( /W) Interrupt Clear Register. */ + }; + __IO uint32_t BAUDDIV; /* Offset: 0x010 (R/W) Baudrate Divider Register. */ + +} CMSDK_UART_TypeDef; + +#define CMSDK_UART0 ((CMSDK_UART_TypeDef *)CMSDK_UART0_BASE) + +/* CMSDK_UART DATA Register Definitions. */ +#define CMSDK_UART_DATA_Pos 0 /* CMSDK_UART_DATA_Pos: DATA Position. */ +#define CMSDK_UART_DATA_Msk (0xFFul << CMSDK_UART_DATA_Pos) /* CMSDK_UART DATA: DATA Mask. */ + +/* CMSDK_UART STATE Register Definitions. */ +#define CMSDK_UART_STATE_RXOR_Pos 3 /* CMSDK_UART STATE: RXOR Position. */ +#define CMSDK_UART_STATE_RXOR_Msk (0x1ul << CMSDK_UART_STATE_RXOR_Pos) /* CMSDK_UART STATE: RXOR Mask. */ + +#define CMSDK_UART_STATE_TXOR_Pos 2 /* CMSDK_UART STATE: TXOR Position. */ +#define CMSDK_UART_STATE_TXOR_Msk (0x1ul << CMSDK_UART_STATE_TXOR_Pos) /* CMSDK_UART STATE: TXOR Mask. */ + +#define CMSDK_UART_STATE_RXBF_Pos 1 /* CMSDK_UART STATE: RXBF Position. */ +#define CMSDK_UART_STATE_RXBF_Msk (0x1ul << CMSDK_UART_STATE_RXBF_Pos) /* CMSDK_UART STATE: RXBF Mask. */ + +#define CMSDK_UART_STATE_TXBF_Pos 0 /* CMSDK_UART STATE: TXBF Position. */ +#define CMSDK_UART_STATE_TXBF_Msk (0x1ul << CMSDK_UART_STATE_TXBF_Pos ) /* CMSDK_UART STATE: TXBF Mask. */ + +/* CMSDK_UART CTRL Register Definitions. */ +#define CMSDK_UART_CTRL_HSTM_Pos 6 /* CMSDK_UART CTRL: HSTM Position. */ +#define CMSDK_UART_CTRL_HSTM_Msk (0x01ul << CMSDK_UART_CTRL_HSTM_Pos) /* CMSDK_UART CTRL: HSTM Mask. */ + +#define CMSDK_UART_CTRL_RXORIRQEN_Pos 5 /* CMSDK_UART CTRL: RXORIRQEN Position. */ +#define CMSDK_UART_CTRL_RXORIRQEN_Msk (0x01ul << CMSDK_UART_CTRL_RXORIRQEN_Pos) /* CMSDK_UART CTRL: RXORIRQEN Mask. */ + +#define CMSDK_UART_CTRL_TXORIRQEN_Pos 4 /* CMSDK_UART CTRL: TXORIRQEN Position. */ +#define CMSDK_UART_CTRL_TXORIRQEN_Msk (0x01ul << CMSDK_UART_CTRL_TXORIRQEN_Pos) /* CMSDK_UART CTRL: TXORIRQEN Mask. */ + +#define CMSDK_UART_CTRL_RXIRQEN_Pos 3 /* CMSDK_UART CTRL: RXIRQEN Position. */ +#define CMSDK_UART_CTRL_RXIRQEN_Msk (0x01ul << CMSDK_UART_CTRL_RXIRQEN_Pos) /* CMSDK_UART CTRL: RXIRQEN Mask. */ + +#define CMSDK_UART_CTRL_TXIRQEN_Pos 2 /* CMSDK_UART CTRL: TXIRQEN Position. */ +#define CMSDK_UART_CTRL_TXIRQEN_Msk (0x01ul << CMSDK_UART_CTRL_TXIRQEN_Pos) /* CMSDK_UART CTRL: TXIRQEN Mask. */ + +#define CMSDK_UART_CTRL_RXEN_Pos 1 /* CMSDK_UART CTRL: RXEN Position. */ +#define CMSDK_UART_CTRL_RXEN_Msk (0x01ul << CMSDK_UART_CTRL_RXEN_Pos) /* CMSDK_UART CTRL: RXEN Mask. */ + +#define CMSDK_UART_CTRL_TXEN_Pos 0 /* CMSDK_UART CTRL: TXEN Position. 
*/ +#define CMSDK_UART_CTRL_TXEN_Msk (0x01ul << CMSDK_UART_CTRL_TXEN_Pos) /* CMSDK_UART CTRL: TXEN Mask. */ + +/* CMSDK_UART INTSTATUS\INTCLEAR Register Definitions. */ +#define CMSDK_UART_INT_RXORIRQ_Pos 3 /* CMSDK_UART INT: RXORIRQ Position. */ +#define CMSDK_UART_INT_RXORIRQ_Msk (0x01ul << CMSDK_UART_INT_RXORIRQ_Pos) /* CMSDK_UART INT: RXORIRQ Mask. */ + +#define CMSDK_UART_INT_TXORIRQ_Pos 2 /* CMSDK_UART INT: TXORIRQ Position. */ +#define CMSDK_UART_INT_TXORIRQ_Msk (0x01ul << CMSDK_UART_INT_TXORIRQ_Pos) /* CMSDK_UART INT: TXORIRQ Mask. */ + +#define CMSDK_UART_INT_RXIRQ_Pos 1 /* CMSDK_UART INT: RXIRQ Position. */ +#define CMSDK_UART_INT_RXIRQ_Msk (0x01ul << CMSDK_UART_INT_RXIRQ_Pos) /* CMSDK_UART INT: RXIRQ Mask. */ + +#define CMSDK_UART_INT_TXIRQ_Pos 0 /* CMSDK_UART INT: TXIRQ Position. */ +#define CMSDK_UART_INT_TXIRQ_Msk (0x01ul << CMSDK_UART_INT_TXIRQ_Pos) /* CMSDK_UART INT: TXIRQ Mask. */ + +/* CMSDK_UART BAUDDIV Register Definitions. */ +#define CMSDK_UART_BAUDDIV_Pos 0 /* CMSDK_UART BAUDDIV: BAUDDIV Position. */ +#define CMSDK_UART_BAUDDIV_Msk (0xFFFFFul << CMSDK_UART_BAUDDIV_Pos) + +/** + * @brief Gets the core clock set for MPS3. + * @return Clock value in Hz. + **/ +uint32_t GetMPS3CoreClock(void); + +#ifdef __cplusplus +} +#endif + +#endif /* DEVICE_MPS3_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/font_9x15_h.h b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/font_9x15_h.h new file mode 100644 index 0000000..b8b6bdc --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/font_9x15_h.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//Font Generated by MikroElektronika GLCD Font Creator 1.2.0.0 +//MikroElektrnika 2011 +//http://www.mikroe.com + +//GLCD FontName : Lucida_Console9x15 +//GLCD FontSize : 9x15 + +#ifndef FONT_9x15_H_H +#define FONT_9x15_H_H + +const unsigned short Font_9x15_h[] = { + 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, /* Code for char num 32. */ + 0x00,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x00,0x10,0x10,0x00,0x00,0x00, /* Code for char num 33. */ + 0x44,0x44,0x44,0x44,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, /* Code for char num 34. */ + 0x00,0x12,0x12,0x24,0x7F,0x24,0x28,0x48,0xFE,0x48,0x90,0x90,0x00,0x00,0x00, /* Code for char num 35. */ + 0x10,0x7C,0x16,0x12,0x12,0x1C,0x38,0x70,0x50,0x50,0x52,0x3E,0x10,0x00,0x00, /* Code for char num 36. */ + 0x00,0x8C,0x92,0x52,0x52,0x2C,0x10,0x08,0x68,0x94,0x92,0x92,0x62,0x00,0x00, /* Code for char num 37. */ + 0x00,0x18,0x24,0x24,0x34,0x18,0x0C,0x12,0xB2,0xE2,0xC2,0xBC,0x00,0x00,0x00, /* Code for char num 38. */ + 0x08,0x08,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, /* Code for char num 39. */ + 0xC0,0x60,0x10,0x10,0x08,0x08,0x08,0x08,0x08,0x08,0x10,0x10,0x60,0xC0,0x00, /* Code for char num 40. 
*/ + 0x0C,0x18,0x20,0x20,0x40,0x40,0x40,0x40,0x40,0x40,0x20,0x20,0x18,0x0C,0x00, /* Code for char num 41. */ + 0x00,0x10,0x92,0xEE,0x18,0x28,0x28,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, /* Code for char num 42. */ + 0x00,0x00,0x00,0x00,0x10,0x10,0x10,0x10,0xFE,0x10,0x10,0x10,0x00,0x00,0x00, /* Code for char num 43. */ + 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x18,0x18,0x10,0x08,0x00, /* Code for char num 44. */ + 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x7C,0x00,0x00,0x00,0x00,0x00,0x00,0x00, /* Code for char num 45. */ + 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x18,0x18,0x00,0x00,0x00, /* Code for char num 46. */ + 0x80,0x40,0x40,0x60,0x20,0x20,0x10,0x10,0x08,0x08,0x0C,0x04,0x04,0x02,0x00, /* Code for char num 47. */ + 0x00,0x38,0x44,0x82,0x82,0x82,0x82,0x82,0x82,0x82,0x44,0x38,0x00,0x00,0x00, /* Code for char num 48. */ + 0x00,0x10,0x1E,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0xFE,0x00,0x00,0x00, /* Code for char num 49. */ + 0x00,0x3E,0x42,0x40,0x40,0x40,0x20,0x10,0x08,0x04,0x02,0x7E,0x00,0x00,0x00, /* Code for char num 50. */ + 0x00,0x3C,0x40,0x40,0x40,0x60,0x38,0x40,0x40,0x40,0x40,0x3C,0x00,0x00,0x00, /* Code for char num 51. */ + 0x00,0x20,0x30,0x28,0x24,0x24,0x22,0x21,0x7F,0x20,0x20,0x20,0x00,0x00,0x00, /* Code for char num 52. */ + 0x00,0x7C,0x04,0x04,0x04,0x1C,0x20,0x40,0x40,0x40,0x20,0x3C,0x00,0x00,0x00, /* Code for char num 53. */ + 0x00,0x78,0x04,0x04,0x02,0x3A,0x46,0x82,0x82,0x82,0x44,0x38,0x00,0x00,0x00, /* Code for char num 54. */ + 0x00,0xFE,0x80,0x40,0x20,0x20,0x10,0x10,0x08,0x08,0x04,0x04,0x00,0x00,0x00, /* Code for char num 55. */ + 0x00,0x3C,0x42,0x42,0x42,0x24,0x1C,0x62,0x42,0x42,0x42,0x3C,0x00,0x00,0x00, /* Code for char num 56. */ + 0x00,0x38,0x44,0x82,0x82,0x82,0xC4,0xB8,0x80,0x40,0x40,0x3C,0x00,0x00,0x00, /* Code for char num 57. */ + 0x00,0x00,0x00,0x00,0x18,0x18,0x00,0x00,0x00,0x00,0x18,0x18,0x00,0x00,0x00, /* Code for char num 58. */ + 0x00,0x00,0x00,0x00,0x18,0x18,0x00,0x00,0x00,0x00,0x18,0x18,0x10,0x08,0x00, /* Code for char num 59. */ + 0x00,0x00,0x00,0x00,0x80,0x60,0x10,0x0C,0x0C,0x10,0x60,0x80,0x00,0x00,0x00, /* Code for char num 60. */ + 0x00,0x00,0x00,0x00,0x00,0x00,0xFE,0x00,0x00,0xFE,0x00,0x00,0x00,0x00,0x00, /* Code for char num 61. */ + 0x00,0x00,0x00,0x00,0x02,0x0C,0x10,0x60,0x60,0x10,0x0C,0x02,0x00,0x00,0x00, /* Code for char num 62. */ + 0x00,0x3E,0x42,0x42,0x40,0x20,0x10,0x08,0x08,0x00,0x08,0x08,0x00,0x00,0x00, /* Code for char num 63. */ + 0x00,0x78,0x84,0xE2,0x92,0x8A,0x8A,0xCA,0xCA,0xB2,0xA6,0x3C,0x00,0x00,0x00, /* Code for char num 64. */ + 0x00,0x00,0x10,0x38,0x28,0x28,0x44,0x44,0xFE,0x82,0x82,0x82,0x00,0x00,0x00, /* Code for char num 65. */ + 0x00,0x00,0x3E,0x42,0x42,0x22,0x1E,0x22,0x42,0x42,0x42,0x3E,0x00,0x00,0x00, /* Code for char num 66. */ + 0x00,0x00,0xF8,0x06,0x02,0x01,0x01,0x01,0x01,0x02,0x06,0xF8,0x00,0x00,0x00, /* Code for char num 67. */ + 0x00,0x00,0x3E,0x42,0x82,0x82,0x82,0x82,0x82,0x82,0x42,0x3E,0x00,0x00,0x00, /* Code for char num 68. */ + 0x00,0x00,0xFE,0x02,0x02,0x02,0x02,0x7E,0x02,0x02,0x02,0xFE,0x00,0x00,0x00, /* Code for char num 69. */ + 0x00,0x00,0xFE,0x02,0x02,0x02,0x02,0x7E,0x02,0x02,0x02,0x02,0x00,0x00,0x00, /* Code for char num 70. */ + 0x00,0x00,0xF8,0x06,0x02,0x01,0x01,0xE1,0x81,0x82,0x86,0xF8,0x00,0x00,0x00, /* Code for char num 71. */ + 0x00,0x00,0x42,0x42,0x42,0x42,0x42,0x7E,0x42,0x42,0x42,0x42,0x00,0x00,0x00, /* Code for char num 72. */ + 0x00,0x00,0xFE,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0xFE,0x00,0x00,0x00, /* Code for char num 73. 
*/ + 0x00,0x00,0x3C,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x1E,0x00,0x00,0x00, /* Code for char num 74. */ + 0x00,0x00,0x42,0x22,0x12,0x0A,0x06,0x0A,0x12,0x22,0x42,0x82,0x00,0x00,0x00, /* Code for char num 75. */ + 0x00,0x00,0x02,0x02,0x02,0x02,0x02,0x02,0x02,0x02,0x02,0xFE,0x00,0x00,0x00, /* Code for char num 76. */ + 0x00,0x00,0x63,0x63,0x63,0x55,0x55,0x55,0x4D,0x49,0x41,0x41,0x00,0x00,0x00, /* Code for char num 77. */ + 0x00,0x00,0x82,0x86,0x8A,0x8A,0x92,0x92,0xA2,0xA2,0xC2,0x82,0x00,0x00,0x00, /* Code for char num 78. */ + 0x00,0x00,0x3C,0x42,0x81,0x81,0x81,0x81,0x81,0x81,0x42,0x3C,0x00,0x00,0x00, /* Code for char num 79. */ + 0x00,0x00,0x3E,0x42,0x42,0x42,0x62,0x1E,0x02,0x02,0x02,0x02,0x00,0x00,0x00, /* Code for char num 80. */ + 0x00,0x00,0x3C,0x42,0x81,0x81,0x81,0x81,0x81,0x81,0x42,0x3C,0x60,0x80,0x00, /* Code for char num 81. */ + 0x00,0x00,0x3E,0x42,0x42,0x42,0x22,0x1E,0x12,0x22,0x42,0x82,0x00,0x00,0x00, /* Code for char num 82. */ + 0x00,0x00,0x7C,0x42,0x02,0x06,0x1C,0x20,0x40,0x40,0x42,0x3E,0x00,0x00,0x00, /* Code for char num 83. */ + 0x00,0x00,0xFE,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x00,0x00,0x00, /* Code for char num 84. */ + 0x00,0x00,0x82,0x82,0x82,0x82,0x82,0x82,0x82,0x82,0x44,0x3C,0x00,0x00,0x00, /* Code for char num 85. */ + 0x00,0x00,0x82,0x82,0x82,0x82,0x44,0x44,0x28,0x28,0x38,0x10,0x00,0x00,0x00, /* Code for char num 86. */ + 0x00,0x00,0x82,0x82,0x92,0x92,0xAA,0xAA,0xAA,0xAA,0x64,0x44,0x00,0x00,0x00, /* Code for char num 87. */ + 0x00,0x00,0x82,0x82,0x44,0x28,0x10,0x10,0x28,0x44,0x82,0x82,0x00,0x00,0x00, /* Code for char num 88. */ + 0x00,0x00,0x82,0x82,0x44,0x44,0x28,0x10,0x10,0x10,0x10,0x10,0x00,0x00,0x00, /* Code for char num 89. */ + 0x00,0x00,0xFF,0x80,0x40,0x20,0x10,0x08,0x04,0x02,0x01,0xFF,0x00,0x00,0x00, /* Code for char num 90. */ + 0xF8,0x08,0x08,0x08,0x08,0x08,0x08,0x08,0x08,0x08,0x08,0x08,0x08,0xF8,0x00, /* Code for char num 91. */ + 0x02,0x04,0x04,0x04,0x08,0x08,0x10,0x10,0x20,0x20,0x20,0x40,0x40,0x80,0x00, /* Code for char num 92. */ + 0x3E,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x3E,0x00, /* Code for char num 93. */ + 0x00,0x10,0x10,0x10,0x28,0x28,0x44,0x44,0x44,0x82,0x00,0x00,0x00,0x00,0x00, /* Code for char num 94. */ + 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xFE,0x00,0x00, /* Code for char num 95. */ + 0x10,0x20,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, /* Code for char num 96. */ + 0x00,0x00,0x00,0x00,0x3C,0x40,0x40,0x78,0x44,0x42,0x62,0xDC,0x00,0x00,0x00, /* Code for char num 97. */ + 0x02,0x02,0x02,0x02,0x7A,0x46,0x82,0x82,0x82,0x82,0x46,0x3A,0x00,0x00,0x00, /* Code for char num 98. */ + 0x00,0x00,0x00,0x00,0xF8,0x04,0x02,0x02,0x02,0x02,0x04,0xF8,0x00,0x00,0x00, /* Code for char num 99. */ + 0x80,0x80,0x80,0x80,0xB8,0xC4,0x82,0x82,0x82,0x82,0xC4,0xBC,0x00,0x00,0x00, /* Code for char num 100. */ + 0x00,0x00,0x00,0x00,0x38,0x44,0x42,0x7E,0x02,0x02,0x04,0x78,0x00,0x00,0x00, /* Code for char num 101. */ + 0xF0,0x08,0x08,0x08,0xFE,0x08,0x08,0x08,0x08,0x08,0x08,0x08,0x00,0x00,0x00, /* Code for char num 102. */ + 0x00,0x00,0x00,0x00,0xB8,0xC4,0x82,0x82,0x82,0x82,0xC4,0xBC,0x80,0x40,0x3C, /* Code for char num 103. */ + 0x02,0x02,0x02,0x02,0x3A,0x46,0x42,0x42,0x42,0x42,0x42,0x42,0x00,0x00,0x00, /* Code for char num 104. */ + 0x18,0x18,0x00,0x00,0x1E,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x00,0x00,0x00, /* Code for char num 105. */ + 0x30,0x30,0x00,0x00,0x3C,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x1E, /* Code for char num 106. 
*/ + 0x02,0x02,0x02,0x02,0x42,0x22,0x12,0x0E,0x0A,0x12,0x22,0x42,0x00,0x00,0x00, /* Code for char num 107. */ + 0x1E,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x00,0x00,0x00, /* Code for char num 108. */ + 0x00,0x00,0x00,0x00,0xDA,0xB6,0x92,0x92,0x92,0x92,0x92,0x92,0x00,0x00,0x00, /* Code for char num 109. */ + 0x00,0x00,0x00,0x00,0x3A,0x46,0x42,0x42,0x42,0x42,0x42,0x42,0x00,0x00,0x00, /* Code for char num 110. */ + 0x00,0x00,0x00,0x00,0x38,0x44,0x82,0x82,0x82,0x82,0x44,0x38,0x00,0x00,0x00, /* Code for char num 111. */ + 0x00,0x00,0x00,0x00,0x7A,0x46,0x82,0x82,0x82,0x82,0x46,0x3A,0x02,0x02,0x02, /* Code for char num 112. */ + 0x00,0x00,0x00,0x00,0xB8,0xC4,0x82,0x82,0x82,0x82,0xC4,0xBC,0x80,0x80,0x80, /* Code for char num 113. */ + 0x00,0x00,0x00,0x00,0xF4,0x8C,0x04,0x04,0x04,0x04,0x04,0x04,0x00,0x00,0x00, /* Code for char num 114. */ + 0x00,0x00,0x00,0x00,0x7C,0x02,0x02,0x0C,0x30,0x40,0x42,0x3E,0x00,0x00,0x00, /* Code for char num 115. */ + 0x00,0x00,0x08,0x08,0xFE,0x08,0x08,0x08,0x08,0x08,0x08,0xF0,0x00,0x00,0x00, /* Code for char num 116. */ + 0x00,0x00,0x00,0x00,0x42,0x42,0x42,0x42,0x42,0x42,0x62,0x5C,0x00,0x00,0x00, /* Code for char num 117. */ + 0x00,0x00,0x00,0x00,0x82,0x82,0x82,0x44,0x44,0x28,0x28,0x10,0x00,0x00,0x00, /* Code for char num 118. */ + 0x00,0x00,0x00,0x00,0x82,0x92,0xAA,0xAA,0xAA,0xAA,0x44,0x44,0x00,0x00,0x00, /* Code for char num 119. */ + 0x00,0x00,0x00,0x00,0x82,0x44,0x28,0x10,0x10,0x28,0x44,0x82,0x00,0x00,0x00, /* Code for char num 120. */ + 0x00,0x00,0x00,0x00,0x82,0x82,0x82,0x44,0x44,0x28,0x28,0x10,0x10,0x0C,0x00, /* Code for char num 121. */ + 0x00,0x00,0x00,0x00,0xFE,0x80,0x40,0x20,0x10,0x08,0x04,0xFE,0x00,0x00,0x00, /* Code for char num 122. */ + 0xE0,0x10,0x10,0x10,0x10,0x10,0x10,0x0C,0x10,0x10,0x10,0x10,0x10,0xE0,0x00, /* Code for char num 123. */ + 0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x00, /* Code for char num 124. */ + 0x0E,0x10,0x10,0x10,0x10,0x10,0x10,0x60,0x10,0x10,0x10,0x10,0x10,0x0E,0x00, /* Code for char num 125. */ + 0x00,0x00,0x00,0x00,0x00,0x00,0x62,0x92,0x8C,0x00,0x00,0x00,0x00,0x00,0x00, /* Code for char num 126. */ + 0x00,0x00,0x00,0x07,0x05,0x05,0x05,0x05,0x05,0x05,0x07,0x00,0x00,0x00,0x00 /* Code for char num 127. */ +}; + + +#endif /* FONT_9x15_H_H */ \ No newline at end of file diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/glcd_mps3.h b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/glcd_mps3.h new file mode 100644 index 0000000..c2810c0 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/glcd_mps3.h @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef GLCD_MPS3_H +#define GLCD_MPS3_H + +#include + +/****************************************************************************** + Color coding + GLCD is coded: 15..11 red, 10..5 green, 4..0 blue (unsigned short) + GLCD_R5, GLCD_G6, GLCD_B5 + original coding: 17..12 red, 11..6 green, 5..0 blue + ORG_R6, ORG_G6, ORG_B6 + + ORG_R1..5 = GLCD_R0..4, ORG_R0 = GLCD_R4 + ORG_G0..5 = GLCD_G0..5, + ORG_B1..5 = GLCD_B0..4, ORG_B0 = GLCD_B4 + + GLCD RGB color definitions +******************************************************************************/ +#define Black 0x0000 /* 0, 0, 0 */ +#define Navy 0x000F /* 0, 0, 128 */ +#define DarkGreen 0x03E0 /* 0, 128, 0 */ +#define DarkCyan 0x03EF /* 0, 128, 128 */ +#define Maroon 0x7800 /* 128, 0, 0 */ +#define Purple 0x780F /* 128, 0, 128 */ +#define Olive 0x7BE0 /* 128, 128, 0 */ +#define LightGrey 0xC618 /* 192, 192, 192 */ +#define DarkGrey 0x7BEF /* 128, 128, 128 */ +#define Blue 0x001F /* 0, 0, 255 */ +#define Green 0x07E0 /* 0, 255, 0 */ +#define Cyan 0x07FF /* 0, 255, 255 */ +#define Red 0xF800 /* 255, 0, 0 */ +#define Magenta 0xF81F /* 255, 0, 255 */ +#define Yellow 0xFFE0 /* 255, 255, 0 */ +#define White 0xFFFF /* 255, 255, 255 */ + +/************************** Orientation configuration ************************/ +#ifndef LANDSCAPE +#define LANDSCAPE 1 /* 1 for landscape, 0 for portrait. */ +#endif +#ifndef ROTATE180 +#define ROTATE180 1 /* 1 to rotate the screen for 180 deg. */ +#endif + +/*------------------------- Speed dependant settings -------------------------*/ + +/* If processor works on high frequency delay has to be increased, it can be + increased by factor 2^N by this constant. */ +#define DELAY_2N 8 + +/*---------------------- Graphic LCD size definitions ------------------------*/ +#if (LANDSCAPE == 1) + #define GLCD_WIDTH 320 /* Screen Width (in pixels). */ + #define GLCD_HEIGHT 240 /* Screen Hight (in pixels). */ +#else + #define GLCD_WIDTH 240 /* Screen Width (in pixels). */ + #define GLCD_HEIGHT 320 /* Screen Hight (in pixels). */ +#endif + +#define BPP 16 /* Bits per pixel. */ +#define BYPP ((BPP+7)/8) /* Bytes per pixel. */ + + +/** + * @brief Initialize the Himax LCD with HX8347-D LCD Controller. + */ +void GLCD_Initialize(void); + +/** + * @brief Set draw window region to whole screen. + */ +void GLCD_WindowMax(void); + +/** + * @brief Set draw window region. + * @param[in] x Horizontal position. + * @param[in] y Vertical position. + * @param[in] w Window width in pixel. + * @param[in] h Window height in pixels. + */ +void GLCD_SetWindow(unsigned int x, unsigned int y, + unsigned int w, unsigned int h); + +/** + * @brief Set foreground color. + * @param[in] color Foreground color. + */ +void GLCD_SetTextColor(unsigned short color); + +/** + * @brief Set background color. + * @param[in] color Background color. + */ +void GLCD_SetBackColor(unsigned short color); + +/** + * @brief Clear display. + * @param[in] color Display clearing color. + * + */ +void GLCD_Clear(unsigned short color); + +/** + * @brief Draw character on given position. + * @param[in] x Horizontal position. + * @param[in] y Vertical position. + * @param[in] cw Character width in pixel. + * @param[in] ch Character height in pixels. + * @param[in] c Pointer to character bitmap. + * + */ +void GLCD_DrawChar(unsigned int x, unsigned int y, + unsigned int cw, unsigned int ch, + unsigned char *c); + +/** + * @brief Display character on given line. + * @param[in] ln Line number. + * @param[in] col Column number. + * @param[in] fi Font index (0 = 9x15). 
+ * @param[in] c ASCII character. + */ +void GLCD_DisplayChar(unsigned int ln, unsigned int col, + unsigned char fi, unsigned char c); + + +/** + * @brief Display string on given line. + * @param[in] ln Line number. + * @param[in] col Column number. + * @param[in] fi Font index (0 = 9x15). + * @param[in] s Pointer to string. + */ +void GLCD_DisplayString(unsigned int ln, unsigned int col, + unsigned char fi, char *s); + +/** + * @brief Clear given line. + * @param[in] ln: Line number. + * @param[in] fi Font index (0 = 9x15). + */ +void GLCD_ClearLn(unsigned int ln, unsigned char fi); + +/** + * @brief Display graphical bitmap image at position x horizontally and y + * vertically. This function is optimized for 16 bits per pixel + * format, it has to be adapted for any other format. + * @param[in] x Horizontal position. + * @param[in] y Vertical position. + * @param[in] w Width of bitmap. + * @param[in] h Height of bitmap. + * @param[in] bitmap Address at which the bitmap data resides. + */ +void GLCD_Bitmap(unsigned int x, unsigned int y, + unsigned int w, unsigned int h, + unsigned short *bitmap); + +/** + * @brief Displays an 8 bit image, conversion to the LCD's + * 16 bit codec is done on the fly. + * @param[in] data Pointer to the full sized image data. + * @param[in] width Image width. + * @param[in] height Image height. + * @param[in] channels Number of channels in the image. + * @param[in] pos_x Start x position for the LCD. + * @param[in] pos_y Start y position for the LCD. + * @param[in] downsample_factor Factor by which the image + * is downsampled by. + */ +void GLCD_Image(void *data, const uint32_t width, + const uint32_t height, const uint32_t channels, + const uint32_t pos_x, const uint32_t pos_y, + const uint32_t downsample_factor); + +/** + * @brief Draw box filled with color. + * @param[in] x Horizontal position. + * @param[in] y Vertical position. + * @param[in] w Window width in pixels. + * @param[in] h Window height in pixels. + * @param[in] color Box color. + */ +void GLCD_Box(unsigned int x, unsigned int y, + unsigned int w, unsigned int h, + unsigned short color); + +#endif /* GLCD_MPS3_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/smm_mps3.h b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/smm_mps3.h new file mode 100644 index 0000000..1c0e0f2 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/include/smm_mps3.h @@ -0,0 +1,615 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SMM_MPS3_H +#define SMM_MPS3_H + +#include "cmsis.h" /* Device specific header file. */ +#include "peripheral_memmap.h" /* Peripheral memory map definitions. 
*/ + +#if defined ( __CC_ARM ) +#pragma anon_unions +#endif + +/******************************************************************************/ +/* FPGA System Register declaration */ +/******************************************************************************/ + +typedef struct +{ + __IO uint32_t LED; /* Offset: 0x000 (R/W) LED connections + * [31:2] : Reserved + * [1:0] : LEDs + */ + uint32_t RESERVED1[1]; + __IO uint32_t BUTTON; /* Offset: 0x008 (R/W) Buttons + * [31:2] : Reserved + * [1:0] : Buttons + */ + uint32_t RESERVED2[1]; + __IO uint32_t CLK1HZ; /* Offset: 0x010 (R/W) 1Hz up counter */ + __IO uint32_t CLK100HZ; /* Offset: 0x014 (R/W) 100Hz up counter */ + __IO uint32_t COUNTER; /* Offset: 0x018 (R/W) Cycle Up Counter + * Increments when 32-bit prescale counter reach zero + */ + __IO uint32_t PRESCALE; /* Offset: 0x01C (R/W) Prescaler + * Bit[31:0] : reload value for prescale counter + */ + __IO uint32_t PSCNTR; /* Offset: 0x020 (R/W) 32-bit Prescale counter + * current value of the pre-scaler counter + * The Cycle Up Counter increment when the prescale down counter reach 0 + * The pre-scaler counter is reloaded with PRESCALE after reaching 0. + */ + uint32_t RESERVED3[1]; + __IO uint32_t SWITCHES; /* Offset: 0x028 (R/W) Switches + * [31:8] : Reserved + * [7:0] : Switches + */ + uint32_t RESERVED4[8]; + __IO uint32_t MISC; /* Offset: 0x04C (R/W) Misc control + * [31:10] : Reserved + * [9] : + * [8] : + * [7] : ADC_SPI_nCS + * [6] : CLCD_BL_CTRL + * [5] : CLCD_RD + * [4] : CLCD_RS + * [3] : CLCD_RESET + * [2] : SHIELD_1_SPI_nCS + * [1] : SHIELD_0_SPI_nCS + * [0] : CLCD_CS + */ +} MPS3_FPGAIO_TypeDef; + +/* MISC register bit definitions. */ + +#define CLCD_CS_Pos 0 +#define CLCD_CS_Msk (1UL< CONTROL + * TX Enable + * <0=> TX disabled + * <1=> TX enabled + * TX IRQ Enable + * <0=> TX IRQ disabled + * <1=> TX IRQ enabled + * RX Enable + * <0=> RX disabled + * <1=> RX enabled + * RX IRQ Enable + * <0=> RX IRQ disabled + * <1=> RX IRQ enabled + * TX Buffer Water Level + * <0=> / IRQ triggers when any space available + * <1=> / IRQ triggers when more than 1 space available + * <2=> / IRQ triggers when more than 2 space available + * <3=> / IRQ triggers when more than 3 space available + * <4=> Undefined! + * <5=> Undefined! + * <6=> Undefined! + * <7=> Undefined! + * RX Buffer Water Level + * <0=> Undefined! + * <1=> / IRQ triggers when less than 1 space available + * <2=> / IRQ triggers when less than 2 space available + * <3=> / IRQ triggers when less than 3 space available + * <4=> / IRQ triggers when less than 4 space available + * <5=> Undefined! + * <6=> Undefined! + * <7=> Undefined! 
+ * FIFO reset + * <0=> Normal operation + * <1=> FIFO reset + * Audio Codec reset + * <0=> Normal operation + * <1=> Assert audio Codec reset + */ + /*!< Offset: 0x004 STATUS Register (R/ ) */ + __I uint32_t STATUS; /* STATUS + * TX Buffer alert + * <0=> TX buffer don't need service yet + * <1=> TX buffer need service + * RX Buffer alert + * <0=> RX buffer don't need service yet + * <1=> RX buffer need service + * TX Buffer Empty + * <0=> TX buffer have data + * <1=> TX buffer empty + * TX Buffer Full + * <0=> TX buffer not full + * <1=> TX buffer full + * RX Buffer Empty + * <0=> RX buffer have data + * <1=> RX buffer empty + * RX Buffer Full + * <0=> RX buffer not full + * <1=> RX buffer full + */ + union { + /*!< Offset: 0x008 Error Status Register (R/ ) */ + __I uint32_t ERROR; /* ERROR + * TX error + * <0=> Okay + * <1=> TX overrun/underrun + * RX error + * <0=> Okay + * <1=> RX overrun/underrun + */ + /*!< Offset: 0x008 Error Clear Register ( /W) */ + __O uint32_t ERRORCLR; /* ERRORCLR + * TX error + * <0=> Okay + * <1=> Clear TX error + * RX error + * <0=> Okay + * <1=> Clear RX error + */ + }; + /*!< Offset: 0x00C Divide ratio Register (R/W) */ + __IO uint32_t DIVIDE; /* Divide ratio for Left/Right clock + * TX error (default 0x80) + */ + /*!< Offset: 0x010 Transmit Buffer ( /W) */ + __O uint32_t TXBUF; /* Transmit buffer + * Right channel + * Left channel + */ + + /*!< Offset: 0x014 Receive Buffer (R/ ) */ + __I uint32_t RXBUF; /* Receive buffer + * Right channel + * Left channel + */ + uint32_t RESERVED1[186]; + __IO uint32_t ITCR; /* Integration Test Control Register + * ITEN + * <0=> Normal operation + * <1=> Integration Test mode enable + */ + __O uint32_t ITIP1; /* Integration Test Input Register 1 + * SDIN + */ + __O uint32_t ITOP1; /* Integration Test Output Register 1 + * SDOUT + * SCLK + * LRCK + * IRQOUT + */ +} MPS3_I2S_TypeDef; + +#define I2S_CONTROL_TXEN_Pos 0 +#define I2S_CONTROL_TXEN_Msk (1UL< +#include + +/* Container for timestamp up-counters. */ +typedef struct _mps3_time_counter { + uint32_t counter_1Hz; + uint32_t counter_100Hz; + + /* Running at FPGA clock rate. See GetMPS3CoreClock(). */ + uint32_t counter_fpga; + + /* Running at processor core's internal clock rate, triggered by SysTick. */ + uint64_t counter_systick; +} mps3_time_counter; + +/** + * @brief Resets the counters. + */ +void timer_reset(void); + +/** + * @brief Gets the current counter values. + * @returns Mps3 timer counter. + **/ +mps3_time_counter get_time_counter(void); + +/** + * @brief Gets the duration elapsed between two counters in milliseconds. + * @param[in] start Pointer to mps3_time_counter value at start time. + * @param[in] end Pointer to mps3_time_counter value at end. + * @returns Difference in milliseconds between the two give counters + * expressed as an unsigned integer. + **/ +uint32_t get_duration_milliseconds(mps3_time_counter *start, + mps3_time_counter *end); + +/** + * @brief Gets the duration elapsed between two counters in microseconds. + * @param[in] start Pointer to mps3_time_counter value at start time. + * @param[in] end Pointer to mps3_time_counter value at end. + * @returns Difference in microseconds between the two give counters + * expressed as an unsigned integer. + **/ +uint32_t get_duration_microseconds(mps3_time_counter *start, + mps3_time_counter *end); + +/** + * @brief Gets the cycle counts elapsed between start and end. + * @param[in] start Pointer to mps3_time_counter value at start time. + * @param[in] end Pointer to mps3_time_counter value at end. 
+ * @return Difference in counter values as 32 bit unsigned integer. + **/ +uint64_t get_cycle_count_diff(mps3_time_counter *start, + mps3_time_counter *end); + +/** + * @brief Enables or triggers cycle counting mechanism, if required + * by the platform. + **/ +void start_cycle_counter(void); + +/** + * @brief Stops cycle counting mechanism, if required by the platform. + **/ +void stop_cycle_counter(void); + +#endif /* TIMER_MPS3_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/timer_mps3.c b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/timer_mps3.c new file mode 100644 index 0000000..0a3a8b1 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/timer_mps3.c @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "timer_mps3.h" + +#include "bsp_core_log.h" +#include "device_mps3.h" + +void timer_reset(void) +{ + MPS3_FPGAIO->CLK1HZ = 0; + MPS3_FPGAIO->CLK100HZ = 0; + MPS3_FPGAIO->COUNTER = 0; + + if (0 != Init_SysTick()) { + printf_err("Failed to initialise system tick config\n"); + } + debug("system tick config ready\n"); +} + +mps3_time_counter get_time_counter(void) +{ + mps3_time_counter t = { + .counter_1Hz = MPS3_FPGAIO->CLK1HZ, + .counter_100Hz = MPS3_FPGAIO->CLK100HZ, + .counter_fpga = MPS3_FPGAIO->COUNTER, + .counter_systick = Get_SysTick_Cycle_Count() + }; + debug("Timestamp:\ + \n\tCounter 1 Hz: %u\ + \n\tCounter 100 Hz: %u\ + \n\tCounter FPGA: %u\ + \n\tCounter CPU: %llu\n", + t.counter_1Hz, t.counter_100Hz, t.counter_fpga, t.counter_systick); + return t; +} + +/** + * Please note, that there are no checks for overflow in this function => if + * the time elapsed has been big (in days) this could happen and is currently + * not handled. + **/ +uint32_t get_duration_milliseconds(mps3_time_counter *start, + mps3_time_counter *end) +{ + uint32_t time_elapsed = 0; + if (end->counter_100Hz > start->counter_100Hz) { + time_elapsed = (end->counter_100Hz - start->counter_100Hz) * 10; + } else { + time_elapsed = (end->counter_1Hz - start->counter_1Hz) * 1000 + + ((0xFFFFFFFF - start->counter_100Hz) + end->counter_100Hz + 1) * 10; + } + + /* If the time elapsed is less than 100ms, use microseconds count to be + * more precise */ + if (time_elapsed < 100) { + debug("Using the microsecond function instead..\n"); + return get_duration_microseconds(start, end)/1000; + } + + return time_elapsed; +} + +/** + * Like the microsecond counterpart, this function could return wrong results when + * the counter (MAINCLK) overflows. There are no overflow counters available. 
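+ * The else-branch below attempts to account for a single wrap of the
+ * 32-bit counter; multiple wraps between the two samples cannot be
+ * detected.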
+ **/ +uint32_t get_duration_microseconds(mps3_time_counter *start, + mps3_time_counter *end) +{ + const int divisor = GetMPS3CoreClock()/1000000; + uint32_t time_elapsed = 0; + if (end->counter_fpga > start->counter_fpga) { + time_elapsed = (end->counter_fpga - start->counter_fpga)/divisor; + } else { + time_elapsed = ((0xFFFFFFFF - end->counter_fpga) + + start->counter_fpga + 1)/divisor; + } + return time_elapsed; +} + +uint64_t get_cycle_count_diff(mps3_time_counter *start, + mps3_time_counter *end) +{ + if (start->counter_systick > end->counter_systick) { + warn("start > end; counter might have overflown\n"); + } + return end->counter_systick - start->counter_systick; +} + +void start_cycle_counter(void) +{ + /* Nothing to do for FPGA */ +} + +void stop_cycle_counter(void) +{ + /* Nothing to do for FPGA */ +} diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/uart_stdout.c b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/uart_stdout.c new file mode 100644 index 0000000..1bf8291 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/mps3/uart_stdout.c @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "uart_stdout.h" + +#include "device_mps3.h" + +#include + +#define CNTLQ 0x11 +#define CNTLS 0x13 +#define DEL 0x7F +#define BACKSPACE 0x08 +#define CR 0x0D +#define LF 0x0A +#define ESC 0x1B + +void UartStdOutInit(void) +{ + /* NOTE: SystemCoreClock should have been set before initialising UART. */ + CMSDK_UART0->BAUDDIV = SystemCoreClock / 115200; /* => (25 or 32 MHz) / (115200 bps). */ + CMSDK_UART0->CTRL = ((1ul << 0) | /* TX enable. */ + (1ul << 1) ); /* RX enable. */ + return; +} + +unsigned char UartPutc(unsigned char my_ch) +{ + while ((CMSDK_UART0->STATE & 1)); /* Wait if Transmit Holding register is full. */ + + if (my_ch == '\n') { + CMSDK_UART0->DATA = '\r'; + while ((CMSDK_UART0->STATE & 1)); /* Wait if Transmit Holding register is full. */ + } + + CMSDK_UART0->DATA = my_ch; /* Write to transmit holding register. */ + return (my_ch); +} + +unsigned char UartGetc(void) +{ + unsigned char my_ch; + unsigned int cnt; + + /* Wait if Receive Holding register is empty. */ + while (0 == (CMSDK_UART0->STATE & 2)) { + cnt = MPS3_FPGAIO->CLK100HZ / 50; + if (cnt & 0x8) { + MPS3_FPGAIO->LED = 0x01 << (cnt & 0x7); + } + else { + MPS3_FPGAIO->LED = 0x80 >> (cnt & 0x7); + } + } + + my_ch = CMSDK_UART0->DATA; + + /* Convert CR to LF. */ + if(my_ch == '\r') { + my_ch = '\n'; + } + + return (my_ch); +} + +bool GetLine(char *lp, unsigned int len) +{ + unsigned int cnt = 0; + char c; + + do { + c = UartGetc (); + switch (c) { + case CNTLQ: /* Ignore Control S/Q. */ + case CNTLS: + break; + + case BACKSPACE: + case DEL: + if (cnt == 0) { + break; + } + cnt--; /* Decrement count. */ + lp--; /* Decrement line pointer. */ + UartPutc (0x08); /* Echo backspace. 
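+ * A space and a second backspace follow so the deleted
+ * character is blanked on the terminal.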
*/ + UartPutc (' '); + UartPutc (0x08); + fflush (stdout); + break; + + case ESC: + case 0: + *lp = 0; /* ESC - stop editing line. */ + return false; + + case CR: /* CR - done, stop editing line. */ + *lp = c; + lp++; /* Increment line pointer */ + cnt++; /* and count. */ + c = LF; + default: + UartPutc (*lp = c); /* Echo and store character. */ + fflush (stdout); + lp++; /* Increment line pointer */ + cnt++; /* and count. */ + break; + } + } while (cnt < len - 2 && c != LF); /* Check limit and CR. */ + *lp = 0; /* Mark end of string. */ + + return true; +} + +void UartEndSimulation(int code) +{ + UartPutc((char) 0x4); /* End of simulation */ + UartPutc((char) code); /* End of simulation */ + while(1); +} diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/stubs_fvp.h b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/stubs_fvp.h new file mode 100644 index 0000000..a21f2d2 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/stubs_fvp.h @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BSP_PACK_FASTMODEL_H +#define BSP_PACK_FASTMODEL_H + +#include "cmsis.h" /* device specific header file */ +#include "peripheral_memmap.h" /* peripheral memory map definitions */ + +/****************************************************************************/ +/* Definitions and stub functions for modules currently */ +/* unavailable on the model */ +/****************************************************************************/ +#define GLCD_WIDTH 320 +#define GLCD_HEIGHT 240 +#define Black 0x0000 /* 0, 0, 0 */ +#define White 0xFFFF /* 255, 255, 255 */ + +/*********************** Clock related functions *****************************/ +uint32_t GetCoreClock(void); + +/************************ GLCD related functions ****************************/ +/** + * @brief Initialize the Himax LCD with HX8347-D LCD Controller + * @return none + */ +void GLCD_Initialize(void); + +/** + * @brief Display graphical bitmap image at position x horizontally and y + * vertically. This function is optimized for 16 bits per pixel + * format, it has to be adapted for any other format. + * @param[in] x horizontal position. + * @param[in] y vertical position. + * @param[in] w width of bitmap. + * @param[in] h height of bitmap. + * @param[in] bitmap address at which the bitmap data resides. + * @return none + */ +void GLCD_Bitmap(unsigned int x, unsigned int y, + unsigned int w, unsigned int h, + unsigned short *bitmap); + +/** + * @brief Displays an 8 bit image, conversion to the LCD's + * 16 bit codec is done on the fly. + * @param[in] data pointer to the full sized image data. + * @param[in] width image width. + * @param[in] height image height. + * @param[in] channels number of channels in the image. + * @param[in] pos_x start x position for the LCD. 
+ * @param[in] pos_y start y position for the LCD. + * @param[in] downsample_factor factor by which the image + * is downsampled by. + * @return none + */ +void GLCD_Image(void *data, const uint32_t width, + const uint32_t height, const uint32_t channels, + const uint32_t pos_x, const uint32_t pos_y, + const uint32_t downsample_factor); + +/** + * @brief Clear display + * @param[in] color display clearing color + * @return none + */ +void GLCD_Clear(unsigned short color); + +/** + * @brief Set foreground color + * @param[in] color foreground color + * @return none + */ +void GLCD_SetTextColor(unsigned short color); + +/** + * @brief Display character on given line + * @param[in] ln line number + * @param[in] col column number + * @param[in] fi font index (0 = 9x15) + * @param[in] c ASCII character + * @return none + */ +void GLCD_DisplayChar(unsigned int ln, unsigned int col, + unsigned char fi, unsigned char c); + +/** + * @brief Display string on given line + * @param[in] ln line number + * @param[in] col column number + * @param[in] fi font index (0 = 9x15) + * @param[in] s pointer to string + * @return none + */ +void GLCD_DisplayString(unsigned int ln, unsigned int col, + unsigned char fi, char *s); + +/** + * @brief Draw box filled with color + * @param[in] x horizontal position + * @param[in] y: vertical position + * @param[in] w: window width in pixels + * @param[in] h: window height in pixels + * @param[in] color box color + * @return none + */ +void GLCD_Box(unsigned int x, unsigned int y, + unsigned int w, unsigned int h, + unsigned short color); + +#endif /* BSP_PACK_FASTMODEL_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/timer_fvp.h b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/timer_fvp.h new file mode 100644 index 0000000..c07a4eb --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/include/timer_fvp.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TIMER_FVP_H +#define TIMER_FVP_H + +#include "stubs_fvp.h" + +/* Container for timestamp for fastmodel. */ +typedef struct _fvp_time_counter { + uint64_t counter_systick; +} fvp_time_counter; + +/** + * @brief Resets the counters. + */ +void timer_reset(void); + +/** + * @brief Gets the current counter values. + * @returns counter struct. + **/ +fvp_time_counter get_time_counter(void); + +/** + * @brief Gets the cycle counts elapsed between start and end. + * @return difference in counter values as 32 bit unsigned integer. + */ +uint64_t get_cycle_count_diff(fvp_time_counter *start, fvp_time_counter *end); + +/** + * @brief Enables or triggers cycle counting mechanism, if required + * by the platform. + */ +void start_cycle_counter(void); + +/** + * @brief Stops cycle counting mechanism, if required by the platform. 
+ */ +void stop_cycle_counter(void); + +#endif /* TIMER_FVP_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/stubs_fvp.c b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/stubs_fvp.c new file mode 100644 index 0000000..e5b2969 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/stubs_fvp.c @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "stubs_fvp.h" + +#include "bsp_core_log.h" + +uint32_t GetCoreClock(void) +{ + return 1; +} + +void GLCD_Initialize(void) {} + +void GLCD_Bitmap(unsigned int x, unsigned int y, + unsigned int w, unsigned int h, unsigned short *bitmap) +{ + UNUSED(x); + UNUSED(y); + UNUSED(w); + UNUSED(h); + UNUSED(bitmap); +} + +void GLCD_Image(void *data, const uint32_t width, const uint32_t height, + const uint32_t channels, const uint32_t pos_x, + const uint32_t pos_y, const uint32_t downsample_factor) +{ + UNUSED(data); + UNUSED(pos_x); + UNUSED(pos_y); + UNUSED(width); + UNUSED(height); + UNUSED(channels); + UNUSED(downsample_factor); + debug("image display: (x, y, w, h) = (%u, %u, %u, %u)\n", + pos_x, pos_y, width, height); + debug("image display: channels = %u, downsample factor = %u\n", + channels, downsample_factor); +} + +void GLCD_Clear(unsigned short color) +{ + UNUSED(color); +} + +void GLCD_SetTextColor(unsigned short color) +{ + UNUSED(color); +} + +void GLCD_DisplayChar (unsigned int ln, unsigned int col, unsigned char fi, + unsigned char c) +{ + UNUSED(ln); + UNUSED(col); + UNUSED(fi); + UNUSED(c); +} + +void GLCD_DisplayString(unsigned int ln, unsigned int col, unsigned char fi, + char *s) +{ + UNUSED(ln); + UNUSED(col); + UNUSED(fi); + UNUSED(s); + debug("text display: %s\n", s); +} + +void GLCD_Box(unsigned int x, unsigned int y, unsigned int w, unsigned int h, + unsigned short color) +{ + UNUSED(x); + UNUSED(y); + UNUSED(w); + UNUSED(h); + UNUSED(color); +} + +void LED_Initialize(uint32_t port) +{ + UNUSED(port); +} + +void LED_On(uint32_t num, uint32_t port) +{ + UNUSED(num); + UNUSED(port); + debug("LED %u ON\n", num); +} + +void LED_Off(uint32_t num, uint32_t port) +{ + UNUSED(num); + UNUSED(port); + debug("LED %u OFF\n", num); +} diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/timer_fvp.c b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/timer_fvp.c new file mode 100644 index 0000000..b7a7232 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/timer_fvp.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "timer_fvp.h" + +#include "irqs.h" +#include "bsp_core_log.h" + +fvp_time_counter get_time_counter(void) +{ + fvp_time_counter t = { + .counter_systick = Get_SysTick_Cycle_Count() + }; + debug("counter_systick: %llu\n", t.counter_systick); + return t; +} + +void timer_reset(void) +{ + if (0 != Init_SysTick()) { + printf_err("Failed to initialise system tick config\n"); + } + debug("system tick config ready\n"); +} + +uint64_t get_cycle_count_diff(fvp_time_counter *start, + fvp_time_counter *end) +{ + if (start->counter_systick > end->counter_systick) { + warn("start > end; counter might have overflown\n"); + } + return end->counter_systick - start->counter_systick; +} + +void start_cycle_counter(void) +{ + /* Add any custom requirement for this platform here */ +} + +void stop_cycle_counter(void) +{ + /* Add any custom requirement for this platform here */ +} diff --git a/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/uart_pl011.c b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/uart_pl011.c new file mode 100644 index 0000000..5c1ee06 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/bsp-packs/simple_platform/uart_pl011.c @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "uart_stdout.h" +#include "peripheral_memmap.h" /* peripheral memory map definitions */ + +#include +#include + +#define CNTLQ 0x11 +#define CNTLS 0x13 +#define DEL 0x7F +#define BACKSPACE 0x08 +#define CR 0x0D +#define LF 0x0A +#define ESC 0x1B + +#define UARTBASE (PL011_UART0_BASE) + +/*****************************************************************************/ +/* UART Control Register Locations */ +/*****************************************************************************/ +#define UART0_DR *((volatile unsigned *) UARTBASE) +#define UART0_RSR *((volatile unsigned *)(UARTBASE + 0x04)) +#define UART0_ECR *((volatile unsigned *)(UARTBASE + 0x04)) +#define UART0_LCRH *((volatile unsigned *)(UARTBASE + 0x2C)) +#define UART0_LCRM *((volatile unsigned *)(UARTBASE + 0x28)) +#define UART0_LCRL *((volatile unsigned *)(UARTBASE + 0x24)) +#define UART0_CR *((volatile unsigned *)(UARTBASE + 0x30)) +#define UART0_FR *((volatile unsigned *)(UARTBASE + 0x18)) +#define UART0_IIR *((volatile unsigned *)(UARTBASE + 0x1C)) +#define UART0_ICR *((volatile unsigned *)(UARTBASE + 0x44)) + +/*****************************************************************************/ +/* Received Status Register - RSR */ +/*****************************************************************************/ +#define RSR_OVERRUN_ERROR 0x08 +#define RSR_BREAK_ERROR 0x04 +#define RSR_PARITY_ERROR 0x02 +#define RSR_FRAMING_ERROR 0x01 + +/*****************************************************************************/ +/* Line Control High Byte Register - LCRH */ +/*****************************************************************************/ +#define LCRH_WORD_LENGTH_8 0x60 +#define LCRH_WORD_LENGTH_7 0x40 +#define LCRH_WORD_LENGTH_6 0x20 +#define LCRH_WORD_LENGTH_5 0x00 +#define LCRH_FIFO_ENABLED 0x10 +#define LCRH_2_STOP_BITS 0x08 +#define LCRH_EVEN_PARITY 0x04 +#define LCRH_PARITY_ENABLE 0x02 +#define LCRH_SEND_BREAK 0x01 + +/*****************************************************************************/ +/* Line Control Medium Byte Register - LCRM */ +/* This register specifies the high byte of the Baud rate divisor */ +/*****************************************************************************/ +#define LCRM_BAUD_460800 0x00 +#define LCRM_BAUD_230400 0x00 +#define LCRM_BAUD_115200 0x00 +#define LCRM_BAUD_76800 0x00 +#define LCRM_BAUD_57600 0x00 +#define LCRM_BAUD_38400 0x00 +#define LCRM_BAUD_19200 0x00 +#define LCRM_BAUD_14400 0x00 +#define LCRM_BAUD_9600 0x00 +#define LCRM_BAUD_2400 0x01 +#define LCRM_BAUD_1200 0x02 + +/*****************************************************************************/ +/* Line Control Low Byte Register - LCRL */ +/* This register specifies the low byte of the Baud rate divisor */ +/*****************************************************************************/ +#define LCRL_BAUD_460800 0x01 +#define LCRL_BAUD_230400 0x03 +#define LCRL_BAUD_115200 0x07 +#define LCRL_BAUD_76800 0x0B +#define LCRL_BAUD_57600 0x0F +#define LCRL_BAUD_38400 0xC +#define LCRL_BAUD_19200 0x2F +#define LCRL_BAUD_14400 0x3F +#define LCRL_BAUD_9600 0x5F +#define LCRL_BAUD_2400 0x7F +#define LCRL_BAUD_1200 0xFF + +/*****************************************************************************/ +/* Control Register - CR */ +/*****************************************************************************/ +#define CR_LOOP_BACK_EN 0x80 +#define CR_TIMEOUT_INT_EN 0x40 +#define CR_TX_INT_ENABLE 0x100 +#define CR_RX_INT_ENABLE 0x200 +#define CR_MODSTAT_INT_EN 0x08 +#define CR_UART_ENABLE 0x01 + 
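+/* For reference, a minimal sketch of selecting a different baud rate with
+ * the divisor tables above. This is hypothetical usage, not exercised by
+ * this file, and it assumes the same reference clock that the 115200 table
+ * entries were derived for:
+ *
+ *     UART0_CR   = 0;                       // disable while reconfiguring
+ *     UART0_LCRL = LCRL_BAUD_9600;          // low byte of baud divisor
+ *     UART0_LCRM = LCRM_BAUD_9600;          // high byte of baud divisor
+ *     UART0_LCRH = LCRH_WORD_LENGTH_8;      // 8 data bits, FIFOs disabled
+ *     UART0_CR   = CR_UART_ENABLE;          // re-enable the UART
+ */
+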
+/*****************************************************************************/ +/* Flag Register - FR */ +/*****************************************************************************/ +#define FR_TX_FIFO_EMPTY 0x80 +#define FR_RX_FIFO_FULL 0x40 +#define FR_TX_FIFO_FULL 0x20 +#define FR_RX_FIFO_EMPTY 0x10 +#define FR_BUSY 0x08 +#define FR_CARRIER_DETECT 0x04 +#define FR_SET_READY 0x02 +#define FR_CLEAR_TO_SEND 0x01 + +/*****************************************************************************/ +/* Interrupt Identification Register - IIR */ +/*****************************************************************************/ +#define IIR_RX_TIME_OUT 0x08 +#define IIR_TX 0x04 +#define IIR_RX 0x02 +#define IIR_MODEM 0x01 + +void UartStdOutInit(void) +{ + /* Disable the serial port while setting the baud rate and word length. */ + UART0_CR = 0; + + /* Clear the receive status register. */ + UART0_ECR = 0; + + /* Set the correct baud rate and word length. */ + UART0_LCRL = LCRL_BAUD_115200; + UART0_LCRM = LCRM_BAUD_115200; + UART0_LCRH = LCRH_WORD_LENGTH_8; + + /* Explicitly disable FIFO's for char mode. */ + UART0_LCRH &= ~LCRH_FIFO_ENABLED; + + /* Enable UART0 (and RX/TX) without interrupts. */ + UART0_CR = CR_UART_ENABLE | CR_TX_INT_ENABLE | CR_RX_INT_ENABLE; +} + +unsigned char UartPutc(unsigned char ch) +{ + if (ch == '\n') { + (void) UartPutc('\r'); + } + while (UART0_FR & FR_TX_FIFO_FULL) + ; + UART0_DR = ch; + + return ch; +} + +unsigned char UartGetc(void) +{ + unsigned char c; + while (UART0_FR & FR_RX_FIFO_EMPTY) + ; + c = UART0_DR; + if (c == '\r') { + c = '\n'; + } + + return c; +} + +bool GetLine (char *lp, unsigned int len) +{ + unsigned int cnt = 0; + char c; + + do { + c = UartGetc(); + switch (c) { + case CNTLQ: /* ignore Control S/Q. */ + case CNTLS: + break; + case BACKSPACE: + case DEL: + if (cnt == 0) { + break; + } + cnt--; /* decrement count. */ + lp--; /* and line pointer. */ + UartPutc (0x08); /* echo backspace. */ + UartPutc (' '); + UartPutc (0x08); + fflush (stdout); + break; + case ESC: + case 0: + *lp = 0; /* ESC - stop editing line. */ + return false; + case CR: /* CR - done, stop editing line. */ + *lp = c; + lp++; /* increment line pointer. */ + cnt++; /* and count. */ + c = LF; + default: + UartPutc (*lp = c); /* echo and store character. */ + fflush (stdout); + lp++; /* increment line pointer. */ + cnt++; /* and count. */ + break; + } + } while (cnt < len - 2 && c != LF); /* check limit and CR. */ + *lp = 0; /* mark end of string. */ + return true; +} + +__attribute__((noreturn)) void UartEndSimulation(int code) +{ + UartPutc((char) 0x4); // End of simulation + UartPutc((char) code); // Exit code + while(1); +} diff --git a/source/application/hal/platforms/bare-metal/bsp/cmsis-device/cmsis.c b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/cmsis.c new file mode 100644 index 0000000..c9cf53d --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/cmsis.c @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "cmsis.h" + +extern void *__Vectors; /* see irqs.c */ + +/*----------------------------------------------------------------------------*\ + * Define clocks (uses OSC1 ACLK) * +\*----------------------------------------------------------------------------*/ +#define __XTAL (25000000) /* Oscillator frequency */ +#define __SYSTEM_CLOCK (__XTAL) + +#define STR(x) #x +#define RESET_REG(n) __ASM volatile("MOV " STR(r##n) ", #0" : : : STR(r##n)) + +#if defined(CPU_CORTEX_M55) +#define CCR_DL (1 << 19) +#else +#error "Invalid CPU; This file only services Cortex-M55 CPUs" +#endif /* (CPU_CORTEX_M55) */ + +/*---------------------------------------------------------------------------- + System Core Clock Variable (Core Clock) + *----------------------------------------------------------------------------*/ +uint32_t SystemCoreClock = __SYSTEM_CLOCK; + + +/*---------------------------------------------------------------------------- + Clock functions + *----------------------------------------------------------------------------*/ +/** + * @brief Updates the SystemCoreClock variable with current core Clock + * retrieved from cpu registers. + */ +void SystemCoreClockUpdate(void) +{ + /* Update the SystemCoreClock variable */ + SystemCoreClock = __SYSTEM_CLOCK; +} + +uint32_t GetSystemCoreClock(void) +{ + return SystemCoreClock; +} + +/** + * @brief Setup the microcontroller system. + * Initialize the System. + **/ +void SystemInit(void) +{ +#if (defined (__FPU_USED) && (__FPU_USED == 1U)) || \ + (defined (__MVE_USED) && (__MVE_USED == 1U)) + SCB->CPACR |= ((3U << 10U*2U) | /* enable CP10 Full Access */ + (3U << 11U*2U) ); +#endif + + /* Initialise registers r0-r12 and LR(=r14) + * They must have a valid value before being potentially pushed to stack by + * C calling convention or by context saving in exception handling + */ + RESET_REG(0); + RESET_REG(1); + RESET_REG(2); + RESET_REG(3); + RESET_REG(4); + RESET_REG(5); + RESET_REG(6); + RESET_REG(7); + RESET_REG(8); + RESET_REG(9); + RESET_REG(10); + RESET_REG(11); + RESET_REG(12); + RESET_REG(14); + +#if defined (__VTOR_PRESENT) && (__VTOR_PRESENT == 1U) + SCB->VTOR = (uint32_t) &__Vectors; +#endif + + /* Enable hard, bus, mem and usage fault detection in SHCSR, bits 16-18. + * Enable stkof, bf, div_0_trp, unalign_trp and usersetm bits in CCR. + */ + SCB->SHCSR = ( + _VAL2FLD(SCB_SHCSR_USGFAULTENA, 1) | + _VAL2FLD(SCB_SHCSR_BUSFAULTENA, 1) | + _VAL2FLD(SCB_SHCSR_MEMFAULTENA, 1)); + + SCB->CCR = (_VAL2FLD(SCB_CCR_USERSETMPEND, 1) | + _VAL2FLD(SCB_CCR_DIV_0_TRP, 1) | + _VAL2FLD(SCB_CCR_BFHFNMIGN, 1) | + _VAL2FLD(SCB_CCR_STKOFHFNMIGN, 1)); +#ifdef UNALIGNED_SUPPORT_DISABLE + SCB->CCR |= _VAL2FLD(SCB_CCR_UNALIGN_TRP, 1); +#endif + + SCB->CCR |= CCR_DL; + + /* Reset pipeline. */ + __DSB(); + __ISB(); + +#ifdef UNALIGNED_SUPPORT_DISABLE + SCB->CCR |= SCB_CCR_UNALIGN_TRP_Msk; +#endif + + SystemCoreClock = __SYSTEM_CLOCK; +} diff --git a/source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/cmsis.h b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/cmsis.h new file mode 100644 index 0000000..969db15 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/cmsis.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BAREMETAL_CMSIS_H +#define BAREMETAL_CMSIS_H + +#include "ARMCM55.h" /* Cortex M system header file from CMSIS. */ +#include "irqs.h" /* Interrupt definitions file. */ + +/* Addition to template functions should be mentioned here. */ + +/** + * @brief Gets the internal processor clock. + * @return Clock frequency as unsigned 32 bit value. + **/ +uint32_t GetSystemCoreClock(void); + +#endif /* BAREMETAL_CMSIS_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/irqs.h b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/irqs.h new file mode 100644 index 0000000..0d8dec6 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/include/irqs.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IRQS_H +#define IRQS_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "peripheral_irqs.h" + +#include + +/* Interrupt handler function type. */ +typedef void (*const irq_vec_type)(void); + +/** + * @brief Reset interrupt handler and also, the starting + * point of the application. + **/ +extern void Reset_Handler(void); + +/** + * @brief Gets the system tick triggered cycle counter for the CPU. + * @return 64-bit counter value. + **/ +extern uint64_t Get_SysTick_Cycle_Count(void); + +/** + * @brief Initialises the system tick registers. + * @return Error code return from sys tick configuration function + * (0 = no error). + **/ +extern int Init_SysTick(void); + +#ifdef __cplusplus +} +#endif + +#endif /* IRQS_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/cmsis-device/irqs.c b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/irqs.c new file mode 100644 index 0000000..c6f54b1 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/cmsis-device/irqs.c @@ -0,0 +1,261 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef __cplusplus +extern "C" +{ +#endif + +#include "irqs.h" +#include "cmsis.h" + +#include + +static uint64_t cpu_cycle_count = 0; + +/** + * @brief Dump core registers on stdout + */ +static void LogCoreCPURegisters(void) +{ + printf("CTRL : 0x%08x\n", __get_CONTROL()); + printf("IPSR : 0x%08x\n", __get_IPSR()); + printf("APSR : 0x%08x\n", __get_APSR()); + printf("xPSR : 0x%08x\n", __get_xPSR()); + printf("PSP : 0x%08x\n", __get_PSP()); + printf("MSP : 0x%08x\n", __get_MSP()); + printf("PRIMASK : 0x%08x\n", __get_PRIMASK()); + printf("BASEPRI : 0x%08x\n", __get_BASEPRI()); + printf("FAULTMSK: 0x%08x\n", __get_FAULTMASK()); + printf("PC : 0x%08x\n", __current_pc()); +} + +/** + * @brief Default interrupt handler - an infinite loop. + **/ +__attribute__((noreturn)) static void DefaultHandler(void) +{ + LogCoreCPURegisters(); + while (1) { + /* Without the following line, armclang may optimize away the + * infinite loop because it'd be without side effects and thus + * undefined behaviour. */ + __ASM volatile(""); + } +} + +#define DEFAULT_HANDLER_CALL(type) \ + do { \ + printf("\n%s caught by function %s\n", \ + type, __FUNCTION__); \ + DefaultHandler(); \ + } while (0) + +#define DEFAULT_ERROR_HANDLER_CALL() \ + DEFAULT_HANDLER_CALL("Exception") + +#define DEFAULT_IRQ_HANDLER_CALL() \ + DEFAULT_HANDLER_CALL("Interrupt") + +/** + * Dummy Exception Handlers for core interrupts. + * + * Weak definitions provided to be used if the user chooses not + * to override them. + **/ + +/** + * @brief Non maskable interrupt handler. + **/ + __attribute__((weak)) void NMI_Handler(void) +{ + DEFAULT_ERROR_HANDLER_CALL(); +} + +/** + * @brief Hardfault interrupt handler. + **/ + __attribute__((weak)) void HardFault_Handler(void) +{ + DEFAULT_ERROR_HANDLER_CALL(); +} + +/** + * @brief Memory management interrupt handler. + **/ +__attribute__((weak)) void MemManage_Handler(void) +{ + DEFAULT_IRQ_HANDLER_CALL(); +} + +/** + * @brief Bus fault interrupt handler. + **/ +__attribute__((weak)) void BusFault_Handler(void) +{ + DEFAULT_ERROR_HANDLER_CALL(); +} + +/** + * @brief Usage fault interrupt handler. + **/ +__attribute__((weak)) void UsageFault_Handler(void) +{ + DEFAULT_ERROR_HANDLER_CALL(); +} + +/** + * @brief Secure access fault interrupt handler. + **/ +__attribute__((weak)) void SecureFault_Handler(void) +{ + DEFAULT_ERROR_HANDLER_CALL(); +} + +/** + * @brief Supervisor call interrupt handler. + **/ +__attribute__((weak)) void SVC_Handler(void) +{ + DEFAULT_IRQ_HANDLER_CALL(); +} + +/** + * @brief Debug monitor interrupt handler. + **/ +__attribute__((weak)) void DebugMon_Handler(void) +{ + DEFAULT_IRQ_HANDLER_CALL(); +} + +/** + * @brief Pending SV call interrupt handler. + */ +__attribute__((weak)) void PendSV_Handler(void) +{ + DEFAULT_IRQ_HANDLER_CALL(); +} + +/** + * @brief System tick interrupt handler. + **/ +void SysTick_Handler(void) +{ + /* Increment the cycle counter based on load value. 
*/ + cpu_cycle_count += SysTick->LOAD + 1; +} + +uint64_t Get_SysTick_Cycle_Count(void) +{ + uint32_t systick_val; + + NVIC_DisableIRQ(SysTick_IRQn); + systick_val = SysTick->VAL & SysTick_VAL_CURRENT_Msk; + NVIC_EnableIRQ(SysTick_IRQn); + + return cpu_cycle_count + (SysTick->LOAD - systick_val); +} + + +/** + * These symbols are provided by the ARM lib - needs the stack and heap + * regions in the scatter file. + */ +extern void Image$$ARM_LIB_STACK$$ZI$$Base(); +extern void Image$$ARM_LIB_STACK$$ZI$$Limit(); +extern void Image$$ARM_LIB_HEAP$$ZI$$Base(); +extern void Image$$ARM_LIB_HEAP$$ZI$$Limit(); +extern __attribute__((noreturn)) void __main(); + +__attribute__((naked, used)) void __user_setup_stackheap() +{ + __ASM volatile("LDR r0, =Image$$ARM_LIB_HEAP$$ZI$$Base"); + __ASM volatile("LDR r1, =Image$$ARM_LIB_STACK$$ZI$$Limit"); + __ASM volatile("LDR r2, =Image$$ARM_LIB_HEAP$$ZI$$Limit"); + __ASM volatile("LDR r3, =Image$$ARM_LIB_STACK$$ZI$$Base"); + __ASM volatile("bx lr"); +} + +/** + * Interrupt vector table. + */ +irq_vec_type __Vectors[] __attribute__((section("RESET"), used)) = { + &Image$$ARM_LIB_STACK$$ZI$$Limit, /* 0 Initial SP */ + &Reset_Handler , /* 1 Initial PC, set to entry point */ + + &NMI_Handler , /* 2 (-14) NMI Handler */ + &HardFault_Handler , /* 3 (-13) Hard Fault Handler */ + &MemManage_Handler , /* 4 (-12) MPU Fault Handler */ + &BusFault_Handler , /* 5 (-11) Bus Fault Handler */ + &UsageFault_Handler , /* 6 (-10) Usage Fault Handler */ + &SecureFault_Handler, /* 7 ( -9) Secure Fault Handler */ + 0 , /* 8 ( -8) Reserved */ + 0 , /* 9 ( -7) Reserved */ + 0 , /* 10 ( -6) Reserved */ + &SVC_Handler , /* 11 ( -5) SVCall Handler */ + &DebugMon_Handler , /* 12 ( -4) Debug Monitor Handler */ + 0 , /* 13 ( -3) Reserved */ + &PendSV_Handler , /* 14 ( -2) PendSV Handler */ + &SysTick_Handler , /* 15 ( -1) SysTick Handler */ + + /* External sources to be populated by user. */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0 - 16 */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 16 - 32 */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 32 - 48 */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 48 - 64 */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 64 - 80 */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 80 - 96 */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 96 - 112 */ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 112 - 128 */ +}; + +int Init_SysTick(void) +{ + const uint32_t ticks_10ms = GetSystemCoreClock()/100 + 1; + int err = 0; + + /* Reset CPU cycle count value. */ + cpu_cycle_count = 0; + + /* Changing configuration for sys tick => guard from being + * interrupted. */ + NVIC_DisableIRQ(SysTick_IRQn); + + /* SysTick init - this will enable interrupt too. */ + err = SysTick_Config(ticks_10ms); + + /* Enable interrupt again. */ + NVIC_EnableIRQ(SysTick_IRQn); + + return err; +} + +/* Reset handler - starting point of our application. */ +__attribute__((used)) void Reset_Handler(void) +{ + /* Initialise system. */ + SystemInit(); + + /* Configure the system tick. */ + Init_SysTick(); + + /* libcxx supplied entry point. */ + __main(); +} + +#ifdef __cplusplus +} +#endif diff --git a/source/application/hal/platforms/bare-metal/bsp/include/bsp.h b/source/application/hal/platforms/bare-metal/bsp/include/bsp.h new file mode 100644 index 0000000..fbe1ff6 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/include/bsp.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BSP_H +#define BSP_H + +/* Core modules - these are common */ +#include "bsp_core_log.h" /* Logging related helpers. */ +#include "uart_stdout.h" /* stdout over UART. */ + +#if defined(MPS3_PLATFORM) /* If running on MPS3 platform. */ + +#include "smm_mps3.h" /* Mem map for MPS3 peripherals. */ +#include "glcd_mps3.h" /* LCD functions. */ +#include "timer_mps3.h" /* Timer functions. */ +#include "device_mps3.h" /* FPGA level definitions and functions. */ + +#else /* MPS3_PLATFORM */ + +#include "stubs_fvp.h" /* Stubs for FVP. */ +#include "timer_fvp.h" /* Timer API for FVP. */ + +#endif /* MPS3_PLATFORM */ + +#endif /* BSP_H */ diff --git a/source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-200.sct b/source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-200.sct new file mode 100644 index 0000000..293193e --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-200.sct @@ -0,0 +1,102 @@ +; Copyright (c) 2021 Arm Limited. All rights reserved. +; SPDX-License-Identifier: Apache-2.0 +; +; Licensed under the Apache License, Version 2.0 (the "License"); +; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; http://www.apache.org/licenses/LICENSE-2.0 +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, +; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +; See the License for the specific language governing permissions and +; limitations under the License. 
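; Editor's illustration (not part of the original patch): the scatter
; description below places any object contributing to the
; ".bss.NoInit.activation_buf" input section into on-chip RAM, and the
; "nn_model"/"ifm" sections into DRAM. A hypothetical way for C code to land
; its tensor arena in that section (the symbol name and size here are made up
; for illustration, not taken from the patch):
;
;   static uint8_t tensor_arena[0x00100000]
;       __attribute__((aligned(16), section(".bss.NoInit.activation_buf")));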
+ +; ************************************************************* +; *** Scatter-Loading Description File *** +; ************************************************************* +; +; Sections used: +;--------------------------------------------------------- +; | Start | End | Size | Remarks | +;-|-------------|-------------|-------------|------------| +; | 0x0000_0000 | 0x0010_0000 | 0x0010_0000 | ITCM (RO) | +; | 0x0010_0000 | 0x0030_0000 | 0x0020_0000 | BRAM (RW) | +; | 0x2000_0000 | 0x2040_0000 | 0x0040_0000 | DTCM (RW) | +; | 0x6000_0000 | 0x6200_0000 | 0x0200_0000 | DRAM (RW) | +;-|-------------|-------------|-------------|------------| +; ITCM is aliased at 0x1000_0000 (single bank) +; BRAM is aliased at 0x1010_0000 +; DTCM is aliased at 0x3000_0000 (four banks of 1MiB each) +; DRAM is aliased at 0x7000_0000 (section is 256MiB) +; +; Note: Ethos-U55 can only access DRAM and BRAM sections +;--------------------------------------------------------- +; First load region +;--------------------------------------------------------- +LOAD_REGION_0 0x00000000 0x00100000 +{ + ;----------------------------------------------------- + ; First part of code mem - 1MiB + ;----------------------------------------------------- + itcm.bin 0x00000000 0x00100000 + { + *.o (RESET, +First) + * (InRoot$$Sections) + .ANY (+RO) + } + + ;----------------------------------------------------- + ; Code memory's 2MiB - reserved for activation buffers + ; Make sure this is uninitialised. + ;----------------------------------------------------- + bram.bin 0x00100000 UNINIT 0x00200000 + { + ; activation buffers a.k.a tensor arena + *.o (.bss.NoInit.activation_buf) + } + + ;----------------------------------------------------- + ; 1MiB bank is used for any other RW or ZI data + ; Note: this region is internal to the Cortex-M CPU + ;----------------------------------------------------- + dtcm.bin 0x20000000 0x00100000 + { + .ANY(+RW +ZI) + } + + ;----------------------------------------------------- + ; 128kiB of stack space within SRAM region + ;----------------------------------------------------- + ARM_LIB_STACK 0x20100000 EMPTY ALIGN 8 0x00020000 + {} + + ;----------------------------------------------------- + ; 2MiB of heap space within the SRAM region + ;----------------------------------------------------- + ARM_LIB_HEAP 0x20200000 EMPTY ALIGN 8 0x00200000 + {} +} + +;--------------------------------------------------------- +; Second load region +;--------------------------------------------------------- +LOAD_REGION_1 0x60000000 0x02000000 +{ + ;----------------------------------------------------- + ; 32 MiB of DRAM space for nn model and input vectors + ;----------------------------------------------------- + dram.bin 0x60000000 0x02000000 + { + ; nn model's baked in input matrices + *.o (ifm) + + ; nn model + *.o (nn_model) + + ; if the activation buffer (tensor arena) doesn't + ; fit in the SRAM region, we accommodate it here + *.o (activation_buf) + } +} diff --git a/source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-300.sct b/source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-300.sct new file mode 100644 index 0000000..327d511 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/mem_layout/mps3-sse-300.sct @@ -0,0 +1,118 @@ +; Copyright (c) 2021 Arm Limited. All rights reserved. +; SPDX-License-Identifier: Apache-2.0 +; +; Licensed under the Apache License, Version 2.0 (the "License"); +; you may not use this file except in compliance with the License. 
+; You may obtain a copy of the License at +; +; http://www.apache.org/licenses/LICENSE-2.0 +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, +; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +; See the License for the specific language governing permissions and +; limitations under the License. + +; ************************************************************* +; *** Scatter-Loading Description File *** +; ************************************************************* +; Please see docs/sections/appendix.md for memory mapping information. +; +; Note: Ethos-U55 can access BRAM, internal SRAM and the DDR sections => activation buffers and +; the model should only be placed in those regions. +; +;--------------------------------------------------------- +; First load region (ITCM) +;--------------------------------------------------------- +LOAD_REGION_0 0x00000000 0x00080000 +{ + ;----------------------------------------------------- + ; First part of code mem - 512kiB + ;----------------------------------------------------- + itcm.bin 0x00000000 0x00080000 + { + *.o (RESET, +First) + * (InRoot$$Sections) + + ; Essentially only RO-CODE, RO-DATA is in a + ; different region. + .ANY (+RO) + } + + ;----------------------------------------------------- + ; 128kiB of 512kiB DTCM is used for any other RW or ZI + ; data. Note: this region is internal to the Cortex-M + ; CPU. + ;----------------------------------------------------- + dtcm.bin 0x20000000 0x00020000 + { + ; Any R/W and/or zero initialised data + .ANY(+RW +ZI) + } + + ;----------------------------------------------------- + ; 384kiB of stack space within the DTCM region. See + ; `dtcm.bin` for the first section. Note: by virtue of + ; being part of DTCM, this region is only accessible + ; from Cortex-M55. + ;----------------------------------------------------- + ARM_LIB_STACK 0x20020000 EMPTY ALIGN 8 0x00060000 + {} + + ;----------------------------------------------------- + ; SSE-300's internal SRAM of 4MiB - reserved for + ; activation buffers. + ; This region should have 3 cycle read latency from + ; both Cortex-M55 and Ethos-U55 + ;----------------------------------------------------- + isram.bin 0x31000000 UNINIT ALIGN 16 0x00400000 + { + ; activation buffers a.k.a tensor arena + *.o (.bss.NoInit.activation_buf) + } +} + +;--------------------------------------------------------- +; Second load region (DDR) +;--------------------------------------------------------- +LOAD_REGION_1 0x70000000 0x02000000 +{ + ;----------------------------------------------------- + ; 32 MiB of DRAM space for neural network model, + ; input vectors and labels. If the activation buffer + ; size required by the network is bigger than the + ; SRAM size available, it is accommodated here. + ;----------------------------------------------------- + dram.bin 0x70000000 ALIGN 16 0x02000000 + { + ; nn model's baked in input matrices + *.o (ifm) + + ; nn model + *.o (nn_model) + + ; labels + *.o (labels) + + ; if the activation buffer (tensor arena) doesn't + ; fit in the SRAM region, we accommodate it here + *.o (activation_buf) + } + + ;----------------------------------------------------- + ; First 256kiB of BRAM (FPGA SRAM) used for RO data. + ; Note: Total BRAM size available is 2MiB. + ;----------------------------------------------------- + bram.bin 0x11000000 ALIGN 8 0x00040000 + { + ; RO data (incl. 
unwinding tables for debugging) + .ANY (+RO-DATA) + } + + ;----------------------------------------------------- + ; Remaining part of the 2MiB BRAM used as heap space. + ; 0x00200000 - 0x00040000 = 0x001C0000 (1.75 MiB) + ;----------------------------------------------------- + ARM_LIB_HEAP 0x11040000 EMPTY ALIGN 8 0x001C0000 + {} +} diff --git a/source/application/hal/platforms/bare-metal/bsp/mem_layout/simple_platform.sct b/source/application/hal/platforms/bare-metal/bsp/mem_layout/simple_platform.sct new file mode 100644 index 0000000..a1ffb49 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/bsp/mem_layout/simple_platform.sct @@ -0,0 +1,102 @@ +; Copyright (c) 2021 Arm Limited. All rights reserved. +; SPDX-License-Identifier: Apache-2.0 +; +; Licensed under the Apache License, Version 2.0 (the "License"); +; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; http://www.apache.org/licenses/LICENSE-2.0 +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, +; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +; See the License for the specific language governing permissions and +; limitations under the License. + +; ************************************************************* +; *** Scatter-Loading Description File *** +; ************************************************************* +; +;--------------------------------------------------------- +; First load region (ITCM) +;--------------------------------------------------------- +LOAD_REGION_0 0x00000000 0x00080000 +{ + ;----------------------------------------------------- + ; First part of code mem - 512kiB + ;----------------------------------------------------- + itcm.bin 0x00000000 0x00080000 + { + *.o (RESET, +First) + * (InRoot$$Sections) + + ; Essentially only RO-CODE, RO-DATA is in a + ; different region. + .ANY (+RO) + } + + ;----------------------------------------------------- + ; BRAM or FPGA data SRAM region worth 2MiB + ;----------------------------------------------------- + bram.bin 0x11000000 UNINIT ALIGN 16 0x00200000 + { + ; activation buffers a.k.a tensor arena + *.o (.bss.NoInit.activation_buf) + } + + ;----------------------------------------------------- + ; 128kiB of 512kiB bank is used for any other RW or ZI + ; data. 
Note: this region is internal to the Cortex-M + ; CPU + ;----------------------------------------------------- + dtcm.bin 0x20000000 0x00020000 + { + .ANY(+RW +ZI) + } + + ;----------------------------------------------------- + ; 128kiB of stack space within the DTCM region + ;----------------------------------------------------- + ARM_LIB_STACK 0x20020000 EMPTY ALIGN 8 0x00020000 + {} + + ;----------------------------------------------------- + ; 256kiB of heap space within the DTCM region + ;----------------------------------------------------- + ARM_LIB_HEAP 0x20040000 EMPTY ALIGN 8 0x00040000 + {} +} + +;--------------------------------------------------------- +; Second load region (DDR) +;--------------------------------------------------------- +LOAD_REGION_1 0x70000000 0x02000000 +{ + ;----------------------------------------------------- + ; 32 MiB of DRAM space for nn model and input vectors + ;----------------------------------------------------- + dram.bin 0x70000000 ALIGN 16 0x02000000 + { + ; nn model's baked in input matrices + *.o (ifm) + + ; nn model + *.o (nn_model) + + ; if the activation buffer (tensor arena) doesn't + ; fit in the SRAM region, we accommodate it here + *.o (activation_buf) + } + + ;----------------------------------------------------- + ; SSE-300's internal SRAM of 2MiB - reserved for + ; activation buffers. + ; This region should have 3 cycle read latency from + ; both Cortex-M55 and Ethos-U55 + ;----------------------------------------------------- + isram.bin 0x31000000 0x00080000 + { + ; RO data (incl. unwinding tables for debugging) + .ANY (+RO-DATA) + } +} diff --git a/source/application/hal/platforms/bare-metal/data_acquisition/data_acq.c b/source/application/hal/platforms/bare-metal/data_acquisition/data_acq.c new file mode 100644 index 0000000..1e40b02 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/data_acquisition/data_acq.c @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "data_acq.h" + +#include "bsp.h" + +#include +#include +#include + +/** + * @brief Get the user input from USART. + * @param[out] user_input String read from the UART block. + * @param[in] size String read length. + * @return 0 if successful, error code otherwise. + **/ +static int get_uart_user_input(char* user_input, int size) +{ + if (true != GetLine(user_input, size - 1)) { + printf_err("invalid input\n"); + return 1; + } + return 0; +} + +int data_acq_channel_init(data_acq_module* module) +{ + assert(module); + + /* UART should have been initialised with low level initialisation + * routines. 
+ */
+    module->system_init = NULL;
+
+    strncpy(module->system_name, "UART", sizeof(module->system_name));
+    module->get_input = get_uart_user_input;
+    module->inited = 1;
+
+    return !(module->inited);
+}
+
+int data_acq_channel_release(data_acq_module* module)
+{
+    assert(module);
+    module->inited = 0;
+    module->get_input = NULL;
+    return 0;
+}
diff --git a/source/application/hal/platforms/bare-metal/data_presentation/data_psn.c b/source/application/hal/platforms/bare-metal/data_presentation/data_psn.c
new file mode 100644
index 0000000..474d552
--- /dev/null
+++ b/source/application/hal/platforms/bare-metal/data_presentation/data_psn.c
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "data_psn.h"
+
+#include "bsp.h"
+#include "lcd_img.h"
+
+#include <assert.h>
+#include <string.h>
+
+int data_psn_system_init(data_psn_module* module)
+{
+    assert(module);
+
+    /* LCD output supported. */
+    module->system_init = lcd_init;
+    module->present_data_image = lcd_display_image;
+    module->present_data_text = lcd_display_text;
+    module->present_box = lcd_display_box;
+    module->set_text_color = lcd_set_text_color;
+    module->clear = lcd_clear;
+    strncpy(module->system_name, "lcd", sizeof(module->system_name));
+    module->inited = !module->system_init();
+    return !module->inited;
+}
+
+int data_psn_system_release(data_psn_module* module)
+{
+    assert(module);
+    module->inited = 0;
+    return 0;
+}
diff --git a/source/application/hal/platforms/bare-metal/data_presentation/lcd/include/lcd_img.h b/source/application/hal/platforms/bare-metal/data_presentation/lcd/include/lcd_img.h
new file mode 100644
index 0000000..e4ad791
--- /dev/null
+++ b/source/application/hal/platforms/bare-metal/data_presentation/lcd/include/lcd_img.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LCD_IMG_H
+#define LCD_IMG_H
+
+#include <stdint.h>
+#include <stddef.h>
+#include <stdbool.h>
+
+/**
+ * @brief       Initialise the LCD.
+ * @return      0 if successful, error code otherwise.
+ **/
+int lcd_init(void);
+
+/**
+ * @brief       Display a given image on the LCD. This allows displaying 8 bit
+ *              single or multi-channel images on the LCD.
+ * @param[in]   data        Pointer to start of the image.
+ * @param[in]   width       Width of this image.
+ * @param[in]   height      Image height.
+ * @param[in]   channels    Number of channels.
+ * @param[in]   pos_x       Screen position x co-ordinate.
+ * @param[in]   pos_y       Screen position y co-ordinate.
+ * @param[in]   downsample_factor   Factor by which the image needs to be
+ *                                  downsampled.
+ * @return      0 if successful, non-zero otherwise.
+ **/
+int lcd_display_image(uint8_t* data, const uint32_t width,
+                      const uint32_t height, const uint32_t channels,
+                      const uint32_t pos_x, const uint32_t pos_y,
+                      const uint32_t downsample_factor);
+
+/**
+ * @brief       Display a given string on the LCD. The text is drawn
+ *              using the 9x15 font (font index 0).
+ * @param[in]   str         Pointer to a null terminated string.
+ * @param[in]   str_sz      Length of the string.
+ * @param[in]   pos_x       Screen position x co-ordinate.
+ * @param[in]   pos_y       Screen position y co-ordinate.
+ * @param[in]   allow_multiple_lines    The function will try to spread
+ *                                      the string into multiple lines if
+ *                                      they don't fit in one.
+ * @return      0 if successful, non-zero otherwise.
+ **/
+int lcd_display_text(const char* str, const size_t str_sz,
+                     const uint32_t pos_x, const uint32_t pos_y,
+                     const bool allow_multiple_lines);
+
+/**
+ * @brief       Display a box with the given color on the LCD.
+ * @param[in]   pos_x       Screen position x co-ordinate.
+ * @param[in]   pos_y       Screen position y co-ordinate.
+ * @param[in]   width       Width.
+ * @param[in]   height      Height.
+ * @param[in]   color       Fill color.
+ * @return      0 if successful, non-zero otherwise.
+ **/
+int lcd_display_box(const uint32_t pos_x, const uint32_t pos_y,
+                    const uint32_t width, const uint32_t height, const uint16_t color);
+
+/**
+ * @brief       Clear LCD.
+ * @param[in]   color   Fill color.
+ * @return      0 if successful, non-zero otherwise.
+ **/
+int lcd_clear(const uint16_t color);
+
+/**
+ * @brief       Set text color.
+ * @param[in]   color   Fill color.
+ * @return      0 if successful, non-zero otherwise.
+ **/
+int lcd_set_text_color(const uint16_t color);
+
+#endif /* LCD_IMG_H */
diff --git a/source/application/hal/platforms/bare-metal/data_presentation/lcd/lcd_img.c b/source/application/hal/platforms/bare-metal/data_presentation/lcd/lcd_img.c
new file mode 100644
index 0000000..75f58fd
--- /dev/null
+++ b/source/application/hal/platforms/bare-metal/data_presentation/lcd/lcd_img.c
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "lcd_img.h" + +#include "bsp.h" + +#include +#include + +static int show_title(void) +{ + char title[128]; + int status = 0; + + /* LCD title string */ +#if defined(CPU_CORTEX_M55) + const char* cpu_name = "Arm Cortex-M55"; +#else /* defined(CPU_CORTEX_M55) */ + const char* cpu_name = "Arm CPU"; +#endif /* defined(CPU_CORTEX_M55) */ + + lcd_set_text_color(White); + + /* First line */ + snprintf(title, sizeof(title), "Arm ML embedded code samples"); + + if (0 != (status = lcd_display_text( + title, strlen(title), 10, 0, false))) { + return status; + } + + /* Second line */ +#if defined (ARM_NPU) + snprintf(title, sizeof(title), "%s + Arm Ethos-U55 NPU", cpu_name); +#else /* defined (ARM_NPU) */ + snprintf(title, sizeof(title), "%s", cpu_name); +#endif /* defined (ARM_NPU) */ + + return lcd_display_text(title, strlen(title), 10, 20, false); +} + +int lcd_init(void) +{ + GLCD_Initialize(); + GLCD_Clear(Black); + return show_title(); +} + +int lcd_display_image(uint8_t* data, const uint32_t width, + const uint32_t height, const uint32_t channels, + const uint32_t pos_x, const uint32_t pos_y, + const uint32_t downsample_factor) +{ + /* Sanity checks */ + assert(data); + if ((pos_x + width/downsample_factor > GLCD_WIDTH) || + (pos_y + height/downsample_factor > GLCD_HEIGHT)) { + printf_err("Invalid image size for given location!\n"); + return 1; + } + + if (1 == channels || 3 == channels) { + GLCD_Image(data, width, height, channels, pos_x, pos_y, + downsample_factor); + } else { + printf_err("Only single and three channel images are supported!\n"); + return 1; + } + + return 0; +} + +int lcd_display_text(const char* str, const size_t str_sz, + const uint32_t pos_x, const uint32_t pos_y, + const bool allow_multiple_lines) +{ + /* We use a font 0 which is 9x15. */ + const uint32_t x_span = 9; /* Each character is this 9 pixels "wide". */ + const uint32_t y_span = 15; /* Each character is this 15 pixels "high". */ + + if (str_sz == 0) { + return 1; + } + + /* If not within the LCD bounds, return error. */ + if (pos_x + x_span > GLCD_WIDTH || pos_y + y_span > GLCD_HEIGHT) { + return 1; + } else { + const unsigned char font_idx = 0; /* We are using the custom font = 0 */ + + const uint32_t col = pos_x/x_span; + const uint32_t max_cols = GLCD_WIDTH/x_span - 1; + const uint32_t max_lines = GLCD_HEIGHT/y_span - 1; + + uint32_t i = 0; + uint32_t current_line = pos_y/y_span; + uint32_t current_col = col; + + /* Display the string on the LCD. */ + for (i = 0; i < str_sz; ++i) { + + if (allow_multiple_lines) { + + /* If the next character won't fit. */ + if (current_col > max_cols) { + current_col = col; + + /* If the next line won't fit. */ + if (++current_line > max_lines) { + return 1; + } + } + } + + GLCD_DisplayChar(current_line, current_col++, font_idx, str[i]); + } + } + return 0; +} + +int lcd_display_box(const uint32_t pos_x, const uint32_t pos_y, + const uint32_t width, const uint32_t height, const uint16_t color) +{ + /* If not within the LCD bounds, return error. 
*/ + if (pos_x > GLCD_WIDTH || pos_y > GLCD_HEIGHT) { + return 1; + } + else { + GLCD_Box(pos_x, pos_y, width, height, color); + } + return 0; +} + +int lcd_clear(const uint16_t color) +{ + GLCD_Clear(color); + GLCD_SetTextColor(White); + return show_title(); +} + +int lcd_set_text_color(const uint16_t color) +{ + GLCD_SetTextColor(color); + return 0; +} diff --git a/source/application/hal/platforms/bare-metal/timer/baremetal_timer.c b/source/application/hal/platforms/bare-metal/timer/baremetal_timer.c new file mode 100644 index 0000000..7257c1d --- /dev/null +++ b/source/application/hal/platforms/bare-metal/timer/baremetal_timer.c @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "bsp.h" +#include "timer.h" + +#include +#include + +#if defined (ARM_NPU) + +#include "pmu_ethosu.h" + +/** + * @brief Initialises the PMU and enables the cycle counter. + **/ +static void _init_ethosu_cyclecounter(void); + +/** + * @brief Gets the difference of total NPU cycle counts. + * (includes active and idle) + * @param[in] st Pointer to time_counter value at start time. + * @param[in] end Pointer to time_counter value at end. + * @return Total NPU cycle counts difference between the arguments expressed + * as unsigned 64 bit integer. + **/ +static uint64_t bm_get_npu_total_cycle_diff(time_counter *st, + time_counter *end); + +/** + * @brief Gets the difference in active NPU cycle counts. + * @param[in] st Pointer to time_counter value at start time. + * @param[in] end Pointer to time_counter value at end. + * @return Active NPU cycle counts difference between the arguments expressed + * as unsigned 64 bit integer. + **/ +static uint64_t bm_get_npu_active_cycle_diff(time_counter *st, + time_counter *end); + +#endif /* defined (ARM_NPU) */ + +#if defined(MPS3_PLATFORM) +/** + * @brief Wrapper for getting milliseconds duration between time counters + * @param[in] st Pointer to time_counter value at start time. + * @param[in] end Pointer to time_counter value at end. + * @return Difference in milliseconds between given time counters. + **/ +static time_t bm_get_duration_ms(time_counter *st, time_counter *end); + +/** + * @brief Wrapper for getting microseconds duration between time counters + * @param[in] st Pointer to time_counter value at start time. + * @param[in] end Pointer to time_counter value at end. + * @return Difference in microseconds between given time counters. + **/ +static time_t bm_get_duration_us(time_counter *st, time_counter *end); +#endif /* defined(MPS3_PLATFORM) */ + +/** + * @brief Wrapper for resetting timer. + **/ +static void bm_timer_reset(void); + +/** + * @brief Wrapper for getting the current timer counter. + * @return Current time counter value. + **/ +static time_counter bm_get_time_counter(void); + +/** + * @brief Wrapper for profiler start. + * @return Current profiler start timer counter. 
+ **/ +static time_counter bm_start_profiling(void); + +/** + * @brief Wrapper for profiler end. + * @return Current profiler end timer counter. + **/ +static time_counter bm_stop_profiling(void); + +/** + * @brief Wrapper for getting CPU cycle difference between time counters. + * @return CPU cycle difference between given time counters expressed + * as unsigned 32 bit integer. + **/ +static uint32_t bm_get_cpu_cycles_diff(time_counter *st, time_counter *end); + +/** + * @brief Initialiser for bare metal timer. + * @param[in] timer Platform timer to initialize. + **/ +void init_timer(platform_timer *timer) +{ + assert(timer); + memset(timer, 0, sizeof(*timer)); + + timer->reset = bm_timer_reset; + timer->get_time_counter = bm_get_time_counter; + timer->start_profiling = bm_start_profiling; + timer->stop_profiling = bm_stop_profiling; + timer->get_cpu_cycle_diff = bm_get_cpu_cycles_diff; + timer->cap.cpu_cycles = 1; + +#if defined (MPS3_PLATFORM) + timer->cap.duration_ms = 1; + timer->cap.duration_us = 1; + timer->get_duration_ms = bm_get_duration_ms; + timer->get_duration_us = bm_get_duration_us; +#endif /* defined (MPS3_PLATFORM) */ + +#if defined (ARM_NPU) + /* We are capable of reporting npu cycle counts. */ + timer->cap.npu_cycles = 1; + timer->get_npu_total_cycle_diff = bm_get_npu_total_cycle_diff; + timer->get_npu_active_cycle_diff = bm_get_npu_active_cycle_diff; + _init_ethosu_cyclecounter(); +#endif /* defined (ARM_NPU) */ + + timer->reset(); + timer->inited = 1; +} + +#if defined (ARM_NPU) + +static void _reset_ethosu_counters(void) +{ + /* Reset all cycle and event counters. */ + ETHOSU_PMU_CYCCNT_Reset(); + ETHOSU_PMU_EVCNTR_ALL_Reset(); +} + +static void _init_ethosu_cyclecounter(void) +{ + /* Reset overflow status. */ + ETHOSU_PMU_Set_CNTR_OVS(ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CCNT_Msk); + + /* Set the counter #0 to count idle cycles. */ + ETHOSU_PMU_Set_EVTYPER(0, ETHOSU_PMU_NPU_IDLE); + + /* Enable PMU. */ + ETHOSU_PMU_Enable(); + + /* Enable counters for cycle and counter# 0. */ + ETHOSU_PMU_CNTR_Enable(ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CCNT_Msk); + + _reset_ethosu_counters(); +} + +static uint64_t bm_get_npu_total_cycle_diff(time_counter *st, time_counter *end) +{ + return end->npu_total_ccnt - st->npu_total_ccnt; +} + +static uint64_t bm_get_npu_active_cycle_diff(time_counter *st, time_counter *end) +{ + /* Check for overflow: The idle counter is 32 bit while the + total cycle count is 64 bit. 
*/ + const uint32_t overflow_status = ETHOSU_PMU_Get_CNTR_OVS(); + + if (ETHOSU_PMU_CNT1_Msk & overflow_status) { + printf_err("EthosU PMU idle counter overflow.\n"); + return 0; + } + + /* Active NPU time = total time - idle time */ + return (bm_get_npu_total_cycle_diff(st, end) + + (uint64_t)(st->npu_idle_ccnt)) - (uint64_t)(end->npu_idle_ccnt); +} + +#endif /* defined (ARM_NPU) */ + +static void bm_timer_reset(void) +{ +#if defined (ARM_NPU) + _init_ethosu_cyclecounter(); +#endif /* defined (ARM_NPU) */ + + timer_reset(); +} + +static time_counter bm_get_time_counter(void) +{ + time_counter t = { + .counter = get_time_counter(), + +#if defined (ARM_NPU) + .npu_idle_ccnt = ETHOSU_PMU_Get_EVCNTR(0), + .npu_total_ccnt = ETHOSU_PMU_Get_CCNTR() +#endif /* defined (ARM_NPU) */ + + }; + +#if defined (ARM_NPU) + debug("NPU total cc: %llu; NPU idle cc: %u\n", + t.npu_total_ccnt, t.npu_idle_ccnt); +#endif /* defined (ARM_NPU) */ + + return t; +} + +static time_counter bm_start_profiling(void) +{ + start_cycle_counter(); + return bm_get_time_counter(); +} + +static time_counter bm_stop_profiling(void) +{ + stop_cycle_counter(); + return bm_get_time_counter(); +} + +static uint32_t bm_get_cpu_cycles_diff(time_counter *st, time_counter *end) +{ + return get_cycle_count_diff(&(st->counter), &(end->counter)); +} + +#if defined(MPS3_PLATFORM) +static time_t bm_get_duration_ms(time_counter *st, time_counter *end) +{ + return get_duration_milliseconds(&(st->counter), &(end->counter)); +} + +static time_t bm_get_duration_us(time_counter *st, time_counter *end) +{ + return get_duration_microseconds(&(st->counter), &(end->counter)); +} +#endif /* defined(MPS3_PLATFORM) */ diff --git a/source/application/hal/platforms/bare-metal/timer/include/baremetal_timer.h b/source/application/hal/platforms/bare-metal/timer/include/baremetal_timer.h new file mode 100644 index 0000000..c8fc32c --- /dev/null +++ b/source/application/hal/platforms/bare-metal/timer/include/baremetal_timer.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef BAREMETAL_TIMER_H +#define BAREMETAL_TIMER_H + +#include +#include + +#if defined (MPS3_PLATFORM) + #include "timer_mps3.h" + typedef mps3_time_counter base_time_counter; +#else /* defined (MPS3_PLATFORM) */ + #include "timer_fvp.h" + typedef fvp_time_counter base_time_counter; +#endif /* defined (MPS3_PLATFORM) */ + +typedef struct bm_time_counter { + base_time_counter counter; + +#if defined (ARM_NPU) + uint64_t npu_total_ccnt; + uint32_t npu_idle_ccnt; +#endif /* ARM_NPU */ + +} time_counter; + +#endif /* BAREMETAL_TIMER_H */ diff --git a/source/application/hal/platforms/bare-metal/utils/include/system_init.h b/source/application/hal/platforms/bare-metal/utils/include/system_init.h new file mode 100644 index 0000000..84e0305 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/utils/include/system_init.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BAREMETAL_SYSTEM_INIT_H +#define BAREMETAL_SYSTEM_INIT_H + +#include "bsp.h" + +/** + * @brief Initialises the platform (MPS3 FPGA board or Fixed Virtual Platform) + * Updates the system core clock and initialises the UART. It also + * verifies that the Cortex-M CPU variant being used matches the expected + * value if running on MPS3. + * @return 0 if successful, error code otherwise. +*/ +int system_init(void); + +/** + * @brief Releases the platform (MPS3 FPGA board or Fixed Virtual Platform). + **/ +void system_release(void); + +/** + * @brief Return the name the platform (MPS3 FPGA board or Fixed Virtual Platform). + * @param[out] name Platform name string. + * @param[in] size Name string length. + **/ +void system_name(char* name, size_t size); + +#endif /* BAREMETAL_SYSTEM_INIT_H */ diff --git a/source/application/hal/platforms/bare-metal/utils/system_init.c b/source/application/hal/platforms/bare-metal/utils/system_init.c new file mode 100644 index 0000000..0a6a1b3 --- /dev/null +++ b/source/application/hal/platforms/bare-metal/utils/system_init.c @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "system_init.h" + +#include "uart_stdout.h" + +#include + +#if defined(MPS3_PLATFORM) +#define CREATE_MASK(msb, lsb) (((1U << ((msb) - (lsb) + 1)) - 1) << (lsb)) +#define MASK_BITS(arg, msb, lsb) ((arg) & CREATE_MASK(msb, lsb)) +#define EXTRACT_BITS(arg, msb, lsb) (MASK_BITS(arg, msb, lsb) >> (lsb)) +#endif /* MPS3_PLATFORM */ + +int system_init(void) +{ +#if defined(MPS3_PLATFORM) + uint32_t id = 0; + uint32_t fpgaid = 0; + uint32_t apnote = 0; + uint32_t rev = 0; + uint32_t aid = 0; + uint32_t fpga_clk = 0; + + /* Initialise the LEDs as the switches are */ + MPS3_FPGAIO->LED = MPS3_FPGAIO->SWITCHES & 0xFF; +#endif + + /* UART init - will enable valid use of printf (stdout + * re-directed at this UART (UART0) */ + UartStdOutInit(); + info("Processor internal clock: %u Hz\n", GetSystemCoreClock()); + +#if defined(MPS3_PLATFORM) + /* Get revision information from various registers */ + rev = MPS3_SCC->CFG_REG4; + fpgaid = MPS3_SCC->SCC_ID; + aid = MPS3_SCC->SCC_AID; + apnote = EXTRACT_BITS(fpgaid, 15, 4); + fpga_clk = GetMPS3CoreClock(); + + info("V2M-MPS3 revision %c\n\n", rev + 'A'); + info("Application Note AN%x, Revision %c\n", apnote, + EXTRACT_BITS(aid, 23, 20) + 'A'); + info("MPS3 build %d\n", EXTRACT_BITS(aid, 31, 24)); + info("MPS3 core clock has been set to: %d Hz\n", fpga_clk); + + /* Display CPU ID */ + id = SCB->CPUID; + info("CPU ID: 0x%08x\n", id); + + if(EXTRACT_BITS(id, 15, 8) == 0xD2) { + if (EXTRACT_BITS(id, 7, 4) == 2) { + info ("CPU: Cortex-M55 r%dp%d\n\n", + EXTRACT_BITS(id, 23, 20),EXTRACT_BITS(id, 3, 0)); +#if defined (CPU_CORTEX_M55) + /* CPU ID should be "0x_41_0f_d2_20" for Cortex-M55 */ + return 0; +#endif /* CPU_CORTEX_M55 */ + } else if (EXTRACT_BITS(id, 7, 4) == 1) { + info ("CPU: Cortex-M33 r%dp%d\n\n", + EXTRACT_BITS(id, 23, 20),EXTRACT_BITS(id, 3, 0)); +#if defined (CPU_CORTEX_M33) + return 0; +#endif /* CPU_CORTEX_M33 */ + } else if (EXTRACT_BITS(id, 7, 4) == 0) { + info ("CPU: Cortex-M23 r%dp%d\n\n", + EXTRACT_BITS(id, 23, 20),EXTRACT_BITS(id, 3, 0)); + } else { + info ("CPU: Cortex-M processor family"); + } + } else if (EXTRACT_BITS(id, 15, 8) == 0xC6) { + info ("CPU: Cortex-M%d+ r%dp%d\n\n", + EXTRACT_BITS(id, 7, 4), EXTRACT_BITS(id, 23, 20), + EXTRACT_BITS(id, 3, 0)); + } else { + info ("CPU: Cortex-M%d r%dp%d\n\n", + EXTRACT_BITS(id, 7, 4), EXTRACT_BITS(id, 23, 20), + EXTRACT_BITS(id, 3, 0)); + } +#else /* MPS3_PLATFORM */ + + info("ARM model environment ready..\n"); + return 0; +#endif /* MPS3_PLATFORM */ + + /* If the CPU is anything other than M33 or M55, we return 1 */ + printf_err("CPU mismatch!\n"); + return 1; +} + +void system_release(void) +{ + __disable_irq(); +} + +void system_name(char* name, size_t size) +{ +#if defined (MPS3_PLATFORM) + strncpy(name, "mps3-bare", size); +#else /* MPS3_PLATFORM */ + strncpy(name, "FVP", size); +#endif /* MPS3_PLATFORM */ +} \ No newline at end of file diff --git a/source/application/hal/platforms/native/data_acquisition/data_acq.c b/source/application/hal/platforms/native/data_acquisition/data_acq.c new file mode 100644 index 0000000..01f47fa --- /dev/null +++ b/source/application/hal/platforms/native/data_acquisition/data_acq.c @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "data_acq.h" + +#include +#include +#include + +/** + * @brief Initialize the acuisition. + * @return 0 if successful, error code otherwise. + **/ +static int acquisition_init(void) +{ + return 0; +} + +/** + * @brief Get the user input from stdin. + * @param[out] user_input String read from the stdin. + * @param[in,out] size String read length. + * @return 0 if successful, error code otherwise. + **/ +static int get_user_input(char* user_input, int size) +{ + fgets(user_input, size, stdin); + return 0; +} + +int data_acq_channel_init(data_acq_module *module) +{ + assert(module); + + module->system_init = acquisition_init; + module->get_input = get_user_input; + strncpy(module->system_name, "native", + sizeof(module->system_name)); + module->inited = !module->system_init(); + return !module->inited; +} + +int data_acq_channel_release(data_acq_module *module) +{ + assert(module); + module->inited = 0; + return 0; +} diff --git a/source/application/hal/platforms/native/data_presentation/data_psn.c b/source/application/hal/platforms/native/data_presentation/data_psn.c new file mode 100644 index 0000000..fe4bcfa --- /dev/null +++ b/source/application/hal/platforms/native/data_presentation/data_psn.c @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "data_psn.h" + +#include "log.h" + +#include +#include + +int data_psn_system_init(data_psn_module *module) +{ + assert(module); + + module->system_init = log_psn_init; + module->present_data_image = log_display_image; + module->present_data_text = log_display_text; + module->present_box = log_display_box_icon; + module->set_text_color = log_set_text_color; + module->clear = log_clear; + strncpy(module->system_name, "log_psn", sizeof(module->system_name)); + module->inited = !module->system_init(); + return !module->inited; +} + +int data_psn_system_release(data_psn_module *module) +{ + /* Nothing to do here! */ + assert(module); + module->inited = 0; + return 0; +} diff --git a/source/application/hal/platforms/native/data_presentation/log/include/log.h b/source/application/hal/platforms/native/data_presentation/log/include/log.h new file mode 100644 index 0000000..10cf303 --- /dev/null +++ b/source/application/hal/platforms/native/data_presentation/log/include/log.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NATIVE_LOG_H +#define NATIVE_LOG_H + +#include +#include +#include + +/** + * @brief Data presentation initialiser + **/ +int log_psn_init(void); + +/** + * @brief Log parameters for the image to be displayed. + * @param[in] data Image pointer. + * @param[in] width Image width. + * @param[in] height Image height. + * @param[in] channels Number of channels. + * @param[in] pos_x Screen position x co-ordinate. + * @param[in] pos_y Screen position y co-ordinate. + * @param[in] downsample_factor Factor by which the image needs to be + * down-sampled. + * @return 0 if successful, non-zero otherwise. + **/ + +int log_display_image(uint8_t* data, const uint32_t width, + const uint32_t height, const uint32_t channels, + const uint32_t pos_x, const uint32_t pos_y, + const uint32_t downsample_factor); + +/** + * @brief Log the parameters for text to be displayed. + * @param[in] str Pointer to a null terminated string. + * @param[in] str_sz Length of the string. + * @param[in] pos_x Screen position x co-ordinate. + * @param[in] pos_y Screen position y co-ordinate. + * @return 0 if successful, non-zero otherwise. + **/ +int log_display_text(const char* str, const size_t str_sz, + const uint32_t pos_x, const uint32_t pos_y, + const bool allow_multiple_lines); + +/** + * @brief Log parameters for the box to be displayed. + * @param[in] pos_x Screen position x co-ordinate. + * @param[in] pos_y Screen position y co-ordinate. + * @param[in] width Width. + * @param[in] height Height. + * @param[in] color Fill color. + * @return 0 if successful, non-zero otherwise. + **/ +int log_display_box_icon(const uint32_t pos_x, const uint32_t pos_y, + const uint32_t width, const uint32_t height, const uint16_t color); + +/** + * @brief Logs the colour with which the display + * needs to be cleared with. + * @param[in] color Fill color. + * @return 0 if successful, non-zero otherwise. + **/ +int log_clear(const uint16_t color); + +/** + * @brief Logs the text color to be set. + * @param[in] color Fill color. + * @return 0 if successful, non-zero otherwise. + **/ +int log_set_text_color (const uint16_t color); + +#endif /* NATIVE_LOG_H */ \ No newline at end of file diff --git a/source/application/hal/platforms/native/data_presentation/log/log.c b/source/application/hal/platforms/native/data_presentation/log/log.c new file mode 100644 index 0000000..48e8b95 --- /dev/null +++ b/source/application/hal/platforms/native/data_presentation/log/log.c @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "log.h" + +#include "dummy_log.h" + +#include + +int log_psn_init(void) +{ + return 0; +} + +int log_display_image(uint8_t* data, const uint32_t width, + const uint32_t height, const uint32_t channels, + const uint32_t pos_x, const uint32_t pos_y, + const uint32_t downsample_factor) +{ + info("Image details\n"); + info("Data: %p\n", data); + info("WxHxC: %dx%dx%d\n", width, height, channels); + info("Pos (x,y): (%d,%d)\n", pos_x, pos_y); + info("Downsampling factor: %u\n", downsample_factor); + return 0; +} + +int log_display_text(const char* str, const size_t str_sz, + const uint32_t pos_x, const uint32_t pos_y, + const bool allow_multiple_lines) +{ + UNUSED(allow_multiple_lines); + info("%s\n", str); + info("Text size: %lu, x: %d, y: %d\n", str_sz, pos_x, pos_y); + return 0; +} + + +int log_display_box_icon(const uint32_t pos_x, const uint32_t pos_y, + const uint32_t width, const uint32_t height, + const uint16_t color) +{ + info("Showing rectangular, width: %d, height: %d, color: %d, x: %d, y: %d\n", + width, height, color, pos_x, pos_y); + return 0; +} + +int log_clear(const uint16_t color) +{ + info("Clearing with color: %d\n", color); + return 0; +} + +int log_set_text_color (const uint16_t color) +{ + info("Setting text color: %d\n", color); + return 0; +} diff --git a/source/application/hal/platforms/native/timer/include/native_timer.h b/source/application/hal/platforms/native/timer/include/native_timer.h new file mode 100644 index 0000000..df7b493 --- /dev/null +++ b/source/application/hal/platforms/native/timer/include/native_timer.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TIMER_H +#define TIMER_H + +#include +#include + +/* Container for time struct */ +typedef struct _time_counter { + /* Current POSIX time in secs. */ + time_t current_secs; + /* Nanoseconds expired in current second. */ + time_t current_nsecs; +} time_counter; + +#endif /* TIMER_H */ \ No newline at end of file diff --git a/source/application/hal/platforms/native/timer/native_timer.cc b/source/application/hal/platforms/native/timer/native_timer.cc new file mode 100644 index 0000000..c115f4d --- /dev/null +++ b/source/application/hal/platforms/native/timer/native_timer.cc @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef __cplusplus +extern "C" { +#endif + +#include "timer.h" + +#include +#include +#include + +#define MILLISECONDS_IN_SECOND 1000 +#define MICROSECONDS_IN_SECOND 1000000 +#define NANOSECONDS_IN_MILLISECOND 1000000 +#define NANOSECONDS_IN_MICROSECOND 1000 + +/** + * @brief Gets the current time counter value. + * @return Counter value expressed in terms of time_counter struct. + **/ +static time_counter get_time_counter(void) +{ + struct timespec current_time{}; + clock_gettime(1, ¤t_time); + time_counter t = { + .current_secs = current_time.tv_sec, + .current_nsecs = current_time.tv_nsec + }; + + return t; +} + +/** + * @brief Gets the time duration elapsed between start and end. + * @param[in] start Pointer to time_counter value at start time. + * @param[in] end Pointer to time_counter value at end. + * @return Difference in milliseconds between the arguments expressed + * as unsigned 32 bit integer. + **/ +static time_t get_duration_milliseconds(time_counter *start, time_counter *end) +{ + /* Convert both parts of time struct to ms then add for complete time. */ + time_t seconds_part = + (end->current_secs - start->current_secs) * MILLISECONDS_IN_SECOND; + time_t nanoseconds_part = + (end->current_nsecs - start->current_nsecs) / NANOSECONDS_IN_MILLISECOND; + + return seconds_part + nanoseconds_part; +} + +/** + * @brief Gets the time duration elapsed between start and end. + * @param[in] start Pointer to time_counter value at start time. + * @param[in] end Pointer to time_counter value at end. + * @return Difference in microseconds between the arguments expressed + * as unsigned 32 bit integer. + **/ +static time_t get_duration_microseconds(time_counter *start, time_counter *end) +{ + /* Convert both parts of time struct to us then add for complete time. */ + time_t seconds_part = + (end->current_secs - start->current_secs) * MICROSECONDS_IN_SECOND; + time_t nanoseconds_part = + (end->current_nsecs - start->current_nsecs) / NANOSECONDS_IN_MICROSECOND; + + return seconds_part + nanoseconds_part; +} + +/** + * @brief Stub for timer reset. + **/ +void reset_timer() {} + +/** + * @brief Initialise the timer for this platform. + **/ +void init_timer(platform_timer *timer) +{ + assert(timer); + memset(timer, 0, sizeof(*timer)); + + timer->get_time_counter = get_time_counter; + timer->start_profiling = get_time_counter; + timer->stop_profiling = get_time_counter; + timer->get_duration_ms = get_duration_milliseconds; + timer->cap.duration_ms = 1; + timer->get_duration_us = get_duration_microseconds; + timer->cap.duration_us = 1; + timer->reset = reset_timer; + timer->inited = 1; +} + +#ifdef __cplusplus +} +#endif diff --git a/source/application/hal/platforms/native/utils/include/dummy_log.h b/source/application/hal/platforms/native/utils/include/dummy_log.h new file mode 100644 index 0000000..626436a --- /dev/null +++ b/source/application/hal/platforms/native/utils/include/dummy_log.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DUMMY_LOG_H +#define DUMMY_LOG_H + +#include + +#define LOG_LEVEL_TRACE 0 +#define LOG_LEVEL_DEBUG 1 +#define LOG_LEVEL_INFO 2 +#define LOG_LEVEL_WARN 3 +#define LOG_LEVEL_ERROR 4 + +#ifndef LOG_LEVEL +#define LOG_LEVEL LOG_LEVEL_INFO +#endif /*LOG_LEVEL*/ + +#define UNUSED(x) ((void)(x)) + +#if (LOG_LEVEL == LOG_LEVEL_TRACE) + #define trace(...) printf("[TRACE] "); printf(__VA_ARGS__) +#else + #define trace(...) +#endif /* LOG_LEVEL == LOG_LEVEL_TRACE */ + +#if (LOG_LEVEL <= LOG_LEVEL_DEBUG) + #define debug(...) printf("[DEBUG] "); printf(__VA_ARGS__) +#else + #define debug(...) +#endif /* LOG_LEVEL > LOG_LEVEL_TRACE */ + +#if (LOG_LEVEL <= LOG_LEVEL_INFO) + #define info(...) printf("[INFO] "); printf(__VA_ARGS__) +#else + #define info(...) +#endif /* LOG_LEVEL > LOG_LEVEL_DEBUG */ + +#if (LOG_LEVEL <= LOG_LEVEL_WARN) + #define warn(...) printf("[WARN] "); printf(__VA_ARGS__) +#else + #define warn(...) +#endif /* LOG_LEVEL > LOG_LEVEL_INFO */ + +#if (LOG_LEVEL <= LOG_LEVEL_ERROR) + #define printf_err(...) printf("[ERROR] "); printf(__VA_ARGS__) +#else + #define printf_err(...) +#endif /* LOG_LEVEL > LOG_LEVEL_INFO */ + +#endif /* DUMMY_LOG_H */ \ No newline at end of file diff --git a/source/application/hal/platforms/native/utils/include/system_init.h b/source/application/hal/platforms/native/utils/include/system_init.h new file mode 100644 index 0000000..80b1bb2 --- /dev/null +++ b/source/application/hal/platforms/native/utils/include/system_init.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NATIVE_SYSTEM_INIT_H +#define NATIVE_SYSTEM_INIT_H + +#include "dummy_log.h" + +/** + * @brief Platform initialisation for native platform. + **/ +int system_init(void); + +/** + * @brief Platform release for native platform. + **/ +void system_release(void); + +/** + * @brief Returns the name of the platform. + * @param[out] name Platform name string. + * @param[in] size Name string length. 
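+ *
+ * Usage sketch (the buffer name and its size below are illustrative
+ * assumptions, not part of this API):
+ *
+ *   char platform_name[64];
+ *   system_name(platform_name, sizeof(platform_name));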
+ */ +void system_name(char* name, size_t size); + +#endif /* NATIVE_SYSTEM_INIT_H */ diff --git a/source/application/hal/platforms/native/utils/system_init.c b/source/application/hal/platforms/native/utils/system_init.c new file mode 100644 index 0000000..8e0b768 --- /dev/null +++ b/source/application/hal/platforms/native/utils/system_init.c @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "system_init.h" + +#include + +int system_init(void) +{ + return 0; +} + +void system_release(void) +{} + +void system_name(char* name, size_t size) +{ + strncpy(name, "native", size); +} \ No newline at end of file diff --git a/source/application/main/Classifier.cc b/source/application/main/Classifier.cc new file mode 100644 index 0000000..bc2c378 --- /dev/null +++ b/source/application/main/Classifier.cc @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Classifier.hpp" + +#include "hal.h" +#include "TensorFlowLiteMicro.hpp" + +#include +#include +#include +#include + +namespace arm { +namespace app { + + template + bool Classifier::_GetTopNResults(TfLiteTensor* tensor, + std::vector& vecResults, + uint32_t topNCount, + const std::vector & labels) + { + std::set> sortedSet; + + /* NOTE: inputVec's size verification against labels should be + * checked by the calling/public function. */ + T* tensorData = tflite::GetTensorData(tensor); + + /* Set initial elements. */ + for (uint32_t i = 0; i < topNCount; ++i) { + sortedSet.insert({tensorData[i], i}); + } + + /* Initialise iterator. */ + auto setFwdIter = sortedSet.begin(); + + /* Scan through the rest of elements with compare operations. */ + for (uint32_t i = topNCount; i < labels.size(); ++i) { + if (setFwdIter->first < tensorData[i]) { + sortedSet.erase(*setFwdIter); + sortedSet.insert({tensorData[i], i}); + setFwdIter = sortedSet.begin(); + } + } + + /* Final results' container. */ + vecResults = std::vector(topNCount); + + /* For getting the floating point values, we need quantization parameters. */ + QuantParams quantParams = GetTensorQuantParams(tensor); + + /* Reset the iterator to the largest element - use reverse iterator. 
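+         * The std::set keeps the {score, index} pairs sorted in ascending
+         * order, so the reverse iterator starts at the highest score and
+         * walks down through the top N entries.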
*/ + auto setRevIter = sortedSet.rbegin(); + + /* Populate results + * Note: we could combine this loop with the loop above, but that + * would, involve more multiplications and other operations. + **/ + for (size_t i = 0; i < vecResults.size(); ++i, ++setRevIter) { + double score = static_cast (setRevIter->first); + vecResults[i].m_normalisedVal = quantParams.scale * + (score - quantParams.offset); + vecResults[i].m_label = labels[setRevIter->second]; + vecResults[i].m_labelIdx = setRevIter->second; + } + + return true; + } + + template<> + bool Classifier::_GetTopNResults(TfLiteTensor* tensor, + std::vector& vecResults, + uint32_t topNCount, + const std::vector & labels) + { + std::set> sortedSet; + + /* NOTE: inputVec's size verification against labels should be + * checked by the calling/public function. */ + float* tensorData = tflite::GetTensorData(tensor); + + /* Set initial elements. */ + for (uint32_t i = 0; i < topNCount; ++i) { + sortedSet.insert({tensorData[i], i}); + } + + /* Initialise iterator. */ + auto setFwdIter = sortedSet.begin(); + + /* Scan through the rest of elements with compare operations. */ + for (uint32_t i = topNCount; i < labels.size(); ++i) { + if (setFwdIter->first < tensorData[i]) { + sortedSet.erase(*setFwdIter); + sortedSet.insert({tensorData[i], i}); + setFwdIter = sortedSet.begin(); + } + } + + /* Final results' container. */ + vecResults = std::vector(topNCount); + + /* Reset the iterator to the largest element - use reverse iterator. */ + auto setRevIter = sortedSet.rbegin(); + + /* Populate results + * Note: we could combine this loop with the loop above, but that + * would, involve more multiplications and other operations. + **/ + for (size_t i = 0; i < vecResults.size(); ++i, ++setRevIter) { + vecResults[i].m_normalisedVal = setRevIter->first; + vecResults[i].m_label = labels[setRevIter->second]; + vecResults[i].m_labelIdx = setRevIter->second; + } + + return true; + } + + template bool Classifier::_GetTopNResults(TfLiteTensor* tensor, + std::vector& vecResults, + uint32_t topNCount, const std::vector & labels); + + template bool Classifier::_GetTopNResults(TfLiteTensor* tensor, + std::vector& vecResults, + uint32_t topNCount, const std::vector & labels); + + bool Classifier::GetClassificationResults( + TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector & labels, uint32_t topNCount) + { + if (outputTensor == nullptr) { + printf_err("Output vector is null pointer.\n"); + return false; + } + + uint32_t totalOutputSize = 1; + for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++){ + totalOutputSize *= outputTensor->dims->data[inputDim]; + } + + /* Sanity checks. */ + if (totalOutputSize < topNCount) { + printf_err("Output vector is smaller than %u\n", topNCount); + return false; + } else if (totalOutputSize != labels.size()) { + printf_err("Output size doesn't match the labels' size\n"); + return false; + } + + bool resultState; + vecResults.clear(); + + /* Get the top N results. 
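+         * Dispatch on the output tensor's element type: quantised uint8/int8
+         * outputs are de-quantised inside _GetTopNResults using the tensor's
+         * scale and offset, while float32 outputs are used as-is.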
*/ + switch (outputTensor->type) { + case kTfLiteUInt8: + resultState = _GetTopNResults(outputTensor, vecResults, topNCount, labels); + break; + case kTfLiteInt8: + resultState = _GetTopNResults(outputTensor, vecResults, topNCount, labels); + break; + case kTfLiteFloat32: + resultState = _GetTopNResults(outputTensor, vecResults, topNCount, labels); + break; + default: + printf_err("Tensor type %s not supported by classifier\n", TfLiteTypeGetName(outputTensor->type)); + return false; + } + + if (!resultState) { + printf_err("Failed to get sorted set\n"); + return false; + } + + return true; + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/application/main/Main.cc b/source/application/main/Main.cc new file mode 100644 index 0000000..6e1c620 --- /dev/null +++ b/source/application/main/Main.cc @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/****************************************************************************\ + * Main application file for ARM NPU on MPS3 board * +\****************************************************************************/ + +#include "hal.h" /* our hardware abstraction api */ +#include "TensorFlowLiteMicro.hpp" /* our inference logic api */ + +#include + +extern void main_loop(hal_platform& platform); + +#if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050) +__ASM(" .global __ARM_use_no_argv\n"); +#endif + +/* Print application information. */ +static void print_application_intro() +{ + info("%s\n", PRJ_DES_STR); + info("Target system design: %s\n", DESIGN_NAME); + info("Version %s Build date: " __DATE__ " @ " __TIME__ "\n", PRJ_VER_STR); + info("Copyright (C) ARM Ltd 2020. All rights reserved.\n\n"); +} + +int main () +{ + hal_platform platform; + data_acq_module dataAcq; + data_psn_module dataPsn; + platform_timer timer; + + /* Initialise the HAL and platform. */ + hal_init(&platform, &dataAcq, &dataPsn, &timer); + + if (0 == hal_platform_init(&platform)) { + /* Application information, UART should have been initialised. */ + print_application_intro(); + + /* Check the version of TensorFlow Lite Micro. */ + PrintTensorFlowVersion(); + + /* Run the application. */ + main_loop(platform); + } + + /* This is unreachable without errors. */ + info("program terminating...\n"); + + /* Release platform. */ + hal_platform_release(&platform); + return 0; +} + diff --git a/source/application/main/Mfcc.cc b/source/application/main/Mfcc.cc new file mode 100644 index 0000000..bf16159 --- /dev/null +++ b/source/application/main/Mfcc.cc @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Mfcc.hpp" + +#include "PlatformMath.hpp" + +#include + +namespace arm { +namespace app { +namespace audio { + + MfccParams::MfccParams( + const float samplingFreq, + const uint32_t numFbankBins, + const float melLoFreq, + const float melHiFreq, + const uint32_t numMfccFeats, + const uint32_t frameLen, + const bool useHtkMethod): + m_samplingFreq(samplingFreq), + m_numFbankBins(numFbankBins), + m_melLoFreq(melLoFreq), + m_melHiFreq(melHiFreq), + m_numMfccFeatures(numMfccFeats), + m_frameLen(frameLen), + + /* Smallest power of 2 >= frame length. */ + m_frameLenPadded(pow(2, ceil((log(frameLen)/log(2))))), + m_useHtkMethod(useHtkMethod) + {} + + std::string MfccParams::Str() + { + char strC[1024]; + snprintf(strC, sizeof(strC) - 1, "\n \ + \n\t Sampling frequency: %f\ + \n\t Number of filter banks: %u\ + \n\t Mel frequency limit (low): %f\ + \n\t Mel frequency limit (high): %f\ + \n\t Number of MFCC features: %u\ + \n\t Frame length: %u\ + \n\t Padded frame length: %u\ + \n\t Using HTK for Mel scale: %s\n", + this->m_samplingFreq, this->m_numFbankBins, this->m_melLoFreq, + this->m_melHiFreq, this->m_numMfccFeatures, this->m_frameLen, + this->m_frameLenPadded, this->m_useHtkMethod ? "yes" : "no"); + return std::string{strC}; + } + + MFCC::MFCC(const MfccParams& params): + _m_params(params), + _m_filterBankInitialised(false) + { + this->_m_buffer = std::vector( + this->_m_params.m_frameLenPadded, 0.0); + this->_m_frame = std::vector( + this->_m_params.m_frameLenPadded, 0.0); + this->_m_melEnergies = std::vector( + this->_m_params.m_numFbankBins, 0.0); + + this->_m_windowFunc = std::vector(this->_m_params.m_frameLen); + const float multiplier = 2 * M_PI / this->_m_params.m_frameLen; + + /* Create window function. */ + for (size_t i = 0; i < this->_m_params.m_frameLen; i++) { + this->_m_windowFunc[i] = (0.5 - (0.5 * + math::MathUtils::CosineF32(static_cast(i) * multiplier))); + } + + math::MathUtils::FftInitF32(this->_m_params.m_frameLenPadded, this->_m_fftInstance); + debug("Instantiated MFCC object: %s\n", this->_m_params.Str().c_str()); + } + + void MFCC::Init() + { + this->_InitMelFilterBank(); + } + + float MFCC::MelScale(const float freq, const bool useHTKMethod) + { + if (useHTKMethod) { + return 1127.0f * logf (1.0f + freq / 700.0f); + } else { + /* Slaney formula for mel scale. */ + + float mel = freq / ms_freqStep; + + if (freq >= ms_minLogHz) { + mel = ms_minLogMel + logf(freq / ms_minLogHz) / ms_logStep; + } + return mel; + } + } + + float MFCC::InverseMelScale(const float melFreq, const bool useHTKMethod) + { + if (useHTKMethod) { + return 700.0f * (expf (melFreq / 1127.0f) - 1.0f); + } else { + /* Slaney formula for mel scale. 
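+             * The mapping is linear below the log threshold (ms_minLogMel)
+             * and exponential above it.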
*/ + float freq = ms_freqStep * melFreq; + + if (melFreq >= ms_minLogMel) { + freq = ms_minLogHz * expf(ms_logStep * (melFreq - ms_minLogMel)); + } + return freq; + } + } + + + bool MFCC::ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) + { + const size_t numBanks = melEnergies.size(); + + if (numBanks != filterBankFilterFirst.size() || + numBanks != filterBankFilterLast.size()) { + printf_err("unexpected filter bank lengths\n"); + return false; + } + + for (size_t bin = 0; bin < numBanks; ++bin) { + auto filterBankIter = melFilterBank[bin].begin(); + float melEnergy = FLT_MIN; /* Avoid log of zero at later stages */ + int32_t firstIndex = filterBankFilterFirst[bin]; + int32_t lastIndex = filterBankFilterLast[bin]; + + for (int i = firstIndex; i <= lastIndex; i++) { + float energyRep = math::MathUtils::SqrtF32(fftVec[i]); + melEnergy += (*filterBankIter++ * energyRep); + } + + melEnergies[bin] = melEnergy; + } + + return true; + } + + void MFCC::ConvertToLogarithmicScale(std::vector& melEnergies) + { + for (size_t bin = 0; bin < melEnergies.size(); ++bin) { + melEnergies[bin] = logf(melEnergies[bin]); + } + } + + void MFCC::_ConvertToPowerSpectrum() + { + const uint32_t halfDim = this->_m_params.m_frameLenPadded / 2; + + /* Handle this special case. */ + float firstEnergy = this->_m_buffer[0] * this->_m_buffer[0]; + float lastEnergy = this->_m_buffer[1] * this->_m_buffer[1]; + + math::MathUtils::ComplexMagnitudeSquaredF32( + this->_m_buffer.data(), + this->_m_buffer.size(), + this->_m_buffer.data(), + this->_m_buffer.size()/2); + + this->_m_buffer[0] = firstEnergy; + this->_m_buffer[halfDim] = lastEnergy; + } + + std::vector MFCC::CreateDCTMatrix( + const int32_t inputLength, + const int32_t coefficientCount) + { + std::vector dctMatix(inputLength * coefficientCount); + + const float normalizer = math::MathUtils::SqrtF32(2.0f/inputLength); + const float angleIncr = M_PI/inputLength; + float angle = 0; + + for (int32_t k = 0, m = 0; k < coefficientCount; k++, m += inputLength) { + for (int32_t n = 0; n < inputLength; n++) { + dctMatix[m+n] = normalizer * + math::MathUtils::CosineF32((n + 0.5) * angle); + } + angle += angleIncr; + } + + return dctMatix; + } + + float MFCC::GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod) + { + UNUSED(leftMel); + UNUSED(rightMel); + UNUSED(useHTKMethod); + + /* By default, no normalisation => return 1 */ + return 1.f; + } + + void MFCC::_InitMelFilterBank() + { + if (!this->_IsMelFilterBankInited()) { + this->_m_melFilterBank = this->_CreateMelFilterBank(); + this->_m_dctMatrix = this->CreateDCTMatrix( + this->_m_params.m_numFbankBins, + this->_m_params.m_numMfccFeatures); + this->_m_filterBankInitialised = true; + } + } + + bool MFCC::_IsMelFilterBankInited() + { + return this->_m_filterBankInitialised; + } + + void MFCC::_MfccComputePreFeature(const std::vector& audioData) + { + this->_InitMelFilterBank(); + + /* TensorFlow way of normalizing .wav data to (-1, 1). */ + constexpr float normaliser = 1.0/(1<<15); + for (size_t i = 0; i < this->_m_params.m_frameLen; i++) { + this->_m_frame[i] = static_cast(audioData[i]) * normaliser; + } + + /* Apply window function to input frame. */ + for(size_t i = 0; i < this->_m_params.m_frameLen; i++) { + this->_m_frame[i] *= this->_m_windowFunc[i]; + } + + /* Set remaining frame values to 0. 
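+         * This zero-pads the windowed frame up to the padded FFT length
+         * (m_frameLenPadded) before the transform is applied.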
*/ + std::fill(this->_m_frame.begin() + this->_m_params.m_frameLen,this->_m_frame.end(), 0); + + /* Compute FFT. */ + math::MathUtils::FftF32(this->_m_frame, this->_m_buffer, this->_m_fftInstance); + + /* Convert to power spectrum. */ + this->_ConvertToPowerSpectrum(); + + /* Apply mel filterbanks. */ + if (!this->ApplyMelFilterBank(this->_m_buffer, + this->_m_melFilterBank, + this->_m_filterBankFilterFirst, + this->_m_filterBankFilterLast, + this->_m_melEnergies)) { + printf_err("Failed to apply MEL filter banks\n"); + } + + /* Convert to logarithmic scale. */ + this->ConvertToLogarithmicScale(this->_m_melEnergies); + } + + std::vector MFCC::MfccCompute(const std::vector& audioData) + { + this->_MfccComputePreFeature(audioData); + + std::vector mfccOut(this->_m_params.m_numMfccFeatures); + + float * ptrMel = this->_m_melEnergies.data(); + float * ptrDct = this->_m_dctMatrix.data(); + float * ptrMfcc = mfccOut.data(); + + /* Take DCT. Uses matrix mul. */ + for (size_t i = 0, j = 0; i < mfccOut.size(); + ++i, j += this->_m_params.m_numFbankBins) { + *ptrMfcc++ = math::MathUtils::DotProductF32( + ptrDct + j, + ptrMel, + this->_m_params.m_numFbankBins); + } + return mfccOut; + } + + std::vector> MFCC::_CreateMelFilterBank() + { + size_t numFftBins = this->_m_params.m_frameLenPadded / 2; + float fftBinWidth = static_cast(this->_m_params.m_samplingFreq) / this->_m_params.m_frameLenPadded; + + float melLowFreq = MFCC::MelScale(this->_m_params.m_melLoFreq, + this->_m_params.m_useHtkMethod); + float melHighFreq = MFCC::MelScale(this->_m_params.m_melHiFreq, + this->_m_params.m_useHtkMethod); + float melFreqDelta = (melHighFreq - melLowFreq) / (this->_m_params.m_numFbankBins + 1); + + std::vector thisBin = std::vector(numFftBins); + std::vector> melFilterBank( + this->_m_params.m_numFbankBins); + this->_m_filterBankFilterFirst = + std::vector(this->_m_params.m_numFbankBins); + this->_m_filterBankFilterLast = + std::vector(this->_m_params.m_numFbankBins); + + for (size_t bin = 0; bin < this->_m_params.m_numFbankBins; bin++) { + float leftMel = melLowFreq + bin * melFreqDelta; + float centerMel = melLowFreq + (bin + 1) * melFreqDelta; + float rightMel = melLowFreq + (bin + 2) * melFreqDelta; + + int32_t firstIndex = -1; + int32_t lastIndex = -1; + const float normaliser = this->GetMelFilterBankNormaliser(leftMel, rightMel, this->_m_params.m_useHtkMethod); + + for (size_t i = 0; i < numFftBins; i++) { + float freq = (fftBinWidth * i); /* Center freq of this fft bin. */ + float mel = MFCC::MelScale(freq, this->_m_params.m_useHtkMethod); + thisBin[i] = 0.0; + + if (mel > leftMel && mel < rightMel) { + float weight; + if (mel <= centerMel) { + weight = (mel - leftMel) / (centerMel - leftMel); + } else { + weight = (rightMel - mel) / (rightMel - centerMel); + } + + thisBin[i] = weight * normaliser; + if (firstIndex == -1) { + firstIndex = i; + } + lastIndex = i; + } + } + + this->_m_filterBankFilterFirst[bin] = firstIndex; + this->_m_filterBankFilterLast[bin] = lastIndex; + + /* Copy the part we care about. */ + for (int32_t i = firstIndex; i <= lastIndex; i++) { + melFilterBank[bin].push_back(thisBin[i]); + } + } + + return melFilterBank; + } + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/application/main/PlatformMath.cc b/source/application/main/PlatformMath.cc new file mode 100644 index 0000000..a9f5049 --- /dev/null +++ b/source/application/main/PlatformMath.cc @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "PlatformMath.hpp" + +#if 0 == ARM_DSP_AVAILABLE + #include + #include +#endif /* 0 == ARM_DSP_AVAILABLE */ + +namespace arm { +namespace app { +namespace math { + + float MathUtils::CosineF32(float radians) + { +#if ARM_DSP_AVAILABLE + return arm_cos_f32(radians); +#else /* ARM_DSP_AVAILABLE */ + return cos(radians); +#endif /* ARM_DSP_AVAILABLE */ + } + + float MathUtils::SqrtF32(float input) + { +#if ARM_DSP_AVAILABLE + float output = 0.f; + arm_sqrt_f32(input, &output); + return output; +#else /* ARM_DSP_AVAILABLE */ + return sqrtf(input); +#endif /* ARM_DSP_AVAILABLE */ + } + + float MathUtils::MeanF32(float* ptrSrc, const uint32_t srcLen) + { + if (!srcLen) { + return 0.f; + } + +#if ARM_DSP_AVAILABLE + float result = 0.f; + arm_mean_f32(ptrSrc, srcLen, &result); + return result; +#else /* ARM_DSP_AVAILABLE */ + float acc = std::accumulate(ptrSrc, ptrSrc + srcLen, 0.0); + return acc/srcLen; +#endif /* ARM_DSP_AVAILABLE */ + } + + float MathUtils::StdDevF32(float* ptrSrc, const uint32_t srcLen, + const float mean) + { + if (!srcLen) { + return 0.f; + } +#if ARM_DSP_AVAILABLE + /** + * Note Standard deviation calculation can be off + * by > 0.01 but less than < 0.1, according to + * preliminary findings. + **/ + UNUSED(mean); + float stdDev = 0; + arm_std_f32(ptrSrc, srcLen, &stdDev); + return stdDev; +#else /* ARM_DSP_AVAILABLE */ + auto VarianceFunction = [=](float acc, const float value) { + return acc + (((value - mean) * (value - mean))/ srcLen); + }; + + float acc = std::accumulate(ptrSrc, ptrSrc + srcLen, 0.0, + VarianceFunction); + + return sqrtf(acc); +#endif /* ARM_DSP_AVAILABLE */ + } + + bool MathUtils::FftInitF32(const uint16_t fftLen, arm::app::math::FftInstance& fftInstance) + { +#if ARM_DSP_AVAILABLE + if (!fftInstance.initialised) { + arm_status status = arm_rfft_fast_init_f32(&fftInstance.instance, fftLen); + + if (ARM_MATH_SUCCESS != status) { + return false; + } + fftInstance.initialised = true; + } +#else + UNUSED(fftLen); + UNUSED(fftInstance); +#endif /* ARM_DSP_AVAILABLE */ + return true; + } + + void MathUtils::FftF32(std::vector& input, + std::vector& fftOutput, + arm::app::math::FftInstance& fftInstance) + { +#if ARM_DSP_AVAILABLE + arm_rfft_fast_f32(&fftInstance.instance, input.data(), fftOutput.data(), 0); +#else + UNUSED(fftInstance); + const int inputLength = input.size(); + + for (int k = 0; k <= inputLength / 2; k++) { + float sumReal = 0, sumImag = 0; + + for (int t = 0; t < inputLength; t++) { + float angle = 2 * M_PI * t * k / inputLength; + sumReal += input[t] * cosf(angle); + sumImag += -input[t] * sinf(angle); + } + + /* Arrange output to [real0, realN/2, real1, im1, real2, im2, ...] 
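+             * This packing mirrors the layout produced by the CMSIS-DSP real
+             * FFT (arm_rfft_fast_f32), so callers see the same output format
+             * whichever path is compiled in.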
*/ + if (k == 0) { + fftOutput[0] = sumReal; + } else if (k == inputLength / 2) { + fftOutput[1] = sumReal; + } else { + fftOutput[k*2] = sumReal; + fftOutput[k*2 + 1] = sumImag; + }; + } +#endif /* ARM_DSP_AVAILABLE */ + } + + void MathUtils::VecLogarithmF32(std::vector & input, + std::vector & output) + { +#if ARM_DSP_AVAILABLE + arm_vlog_f32(input.data(), output.data(), + output.size()); +#else /* ARM_DSP_AVAILABLE */ + for (auto in = input.begin(), out = output.begin(); + in != input.end(); ++in, ++out) { + *out = logf(*in); + } +#endif /* ARM_DSP_AVAILABLE */ + } + + float MathUtils::DotProductF32(float* srcPtrA, float* srcPtrB, + const uint32_t srcLen) + { + float output = 0.f; + +#if ARM_DSP_AVAILABLE + arm_dot_prod_f32(srcPtrA, srcPtrB, srcLen, &output); +#else /* ARM_DSP_AVAILABLE */ + for (uint32_t i = 0; i < srcLen; ++i) { + output += *srcPtrA++ * *srcPtrB++; + } +#endif /* ARM_DSP_AVAILABLE */ + + return output; + } + + bool MathUtils::ComplexMagnitudeSquaredF32(float* ptrSrc, + const uint32_t srcLen, + float* ptrDst, + const uint32_t dstLen) + { + if (dstLen < srcLen/2) { + printf_err("dstLen must be greater than srcLen/2"); + return false; + } + +#if ARM_DSP_AVAILABLE + arm_cmplx_mag_squared_f32(ptrSrc, ptrDst, srcLen/2); +#else /* ARM_DSP_AVAILABLE */ + for (uint32_t j = 0; j < srcLen; ++j) { + const float real = *ptrSrc++; + const float im = *ptrSrc++; + *ptrDst++ = real*real + im*im; + } +#endif /* ARM_DSP_AVAILABLE */ + return true; + } + +} /* namespace math */ +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/application/main/Profiler.cc b/source/application/main/Profiler.cc new file mode 100644 index 0000000..f364759 --- /dev/null +++ b/source/application/main/Profiler.cc @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "Profiler.hpp" + +#include +#include +#include + +namespace arm { +namespace app { + + template + static void writeStatLine(std::ostringstream& s, + const char* desc, + T total, + uint32_t samples, + T min, + T max) + { + s << "\t" << desc << total << " / " + << ((double)total / samples) << " / " + << min << " / " << max << std::endl; + } + + Profiler::Profiler(hal_platform* platform, const char* name = "Unknown") + : _m_name(name) + { + if (platform && platform->inited) { + this->_m_pPlatform = platform; + this->Reset(); + } else { + printf_err("Profiler %s initialised with invalid platform\n", + this->_m_name.c_str()); + } + } + + bool Profiler::StartProfiling(const char* name) + { + if (name) { + this->SetName(name); + } + if (this->_m_pPlatform && !this->_m_started) { + this->_m_pPlatform->timer->reset(); + this->_m_tstampSt = this->_m_pPlatform->timer->start_profiling(); + this->_m_started = true; + return true; + } + printf_err("Failed to start profiler %s\n", this->_m_name.c_str()); + return false; + } + + bool Profiler::StopProfiling() + { + if (this->_m_pPlatform && this->_m_started) { + this->_m_tstampEnd = this->_m_pPlatform->timer->stop_profiling(); + this->_m_started = false; + + this->_AddProfilingUnit(this->_m_tstampSt, this->_m_tstampEnd, this->_m_name); + + return true; + } + printf_err("Failed to stop profiler %s\n", this->_m_name.c_str()); + return false; + } + + bool Profiler::StopProfilingAndReset() + { + if (this->StopProfiling()) { + this->Reset(); + return true; + } + printf_err("Failed to stop profiler %s\n", this->_m_name.c_str()); + return false; + } + + void Profiler::Reset() + { + this->_m_started = false; + memset(&this->_m_tstampSt, 0, sizeof(this->_m_tstampSt)); + memset(&this->_m_tstampEnd, 0, sizeof(this->_m_tstampEnd)); + } + + std::string Profiler::GetResultsAndReset() + { + std::ostringstream strResults; + + for (const auto& item: this->_m_series) { + auto name = item.first; + ProfilingSeries series = item.second; + + uint32_t samplesNum = series.size(); + + uint64_t totalNpuCycles = 0; /* Total NPU cycles (idle + active). */ + uint64_t totalActiveNpuCycles = 0; /* Active NPU cycles. */ + uint64_t totalCpuCycles = 0; /* Total CPU cycles. 
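+             * (Active NPU cycles are subtracted from this later to
+             *  approximate CPU-only cycles in the report.)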
*/ + time_t totalTimeMs = 0; + + uint64_t minActiveNpuCycles = series[0].activeNpuCycles; + uint64_t minIdleNpuCycles = series[0].npuCycles - minActiveNpuCycles; + uint64_t minActiveCpuCycles = series[0].cpuCycles - minActiveNpuCycles; + time_t minTimeMs = series[0].time; + + uint64_t maxIdleNpuCycles = 0; + uint64_t maxActiveNpuCycles = 0; + uint64_t maxActiveCpuCycles = 0; + time_t maxTimeMs = 0; + + for(ProfilingUnit& unit: series){ + totalNpuCycles += unit.npuCycles; + totalActiveNpuCycles += unit.activeNpuCycles; + totalCpuCycles += unit.cpuCycles; + totalTimeMs += unit.time; + + maxActiveNpuCycles = std::max(maxActiveNpuCycles, + unit.activeNpuCycles); + maxIdleNpuCycles = std::max(maxIdleNpuCycles, + unit.npuCycles - maxActiveNpuCycles); + maxActiveCpuCycles = std::max(maxActiveCpuCycles, + unit.cpuCycles - maxActiveNpuCycles); + maxTimeMs = std::max(maxTimeMs, unit.time); + + minActiveNpuCycles = std::min(minActiveNpuCycles, + unit.activeNpuCycles); + minIdleNpuCycles = std::min(minIdleNpuCycles, + unit.npuCycles - minActiveNpuCycles); + minActiveCpuCycles = std::min(minActiveCpuCycles, + unit.cpuCycles - minActiveNpuCycles); + minTimeMs = std::min(minTimeMs, unit.time); + } + + strResults << "Profile for " << name << ": " << std::endl; + + if (samplesNum > 1) { + strResults << "\tSamples: " << samplesNum << std::endl; + strResults << "\t Total / Avg./ Min / Max" + << std::endl; + + writeStatLine(strResults, "Active NPU cycles: ", + totalActiveNpuCycles, samplesNum, + minActiveNpuCycles, maxActiveNpuCycles); + + writeStatLine(strResults, "Idle NPU cycles: ", + (totalNpuCycles - totalActiveNpuCycles), + samplesNum, minIdleNpuCycles, maxIdleNpuCycles); + +#if defined(CPU_PROFILE_ENABLED) + writeStatLine(strResults, "Active CPU cycles (approx): ", + (totalCpuCycles - totalActiveNpuCycles), + samplesNum, minActiveCpuCycles, + maxActiveCpuCycles); + + writeStatLine(strResults, "Time in ms: ", + totalTimeMs, samplesNum, minTimeMs, maxTimeMs); +#endif + } else { + strResults << "\tActive NPU cycles: " << totalActiveNpuCycles + << std::endl; + strResults << "\tIdle NPU cycles: " + << (totalNpuCycles - totalActiveNpuCycles) + << std::endl; +#if defined(CPU_PROFILE_ENABLED) + strResults << "\tActive CPU cycles: " + << (totalCpuCycles - totalActiveNpuCycles) + << " (approx)" << std::endl; + + strResults << "\tTime in ms: " << totalTimeMs << std::endl; +#endif + } + } + this->Reset(); + return strResults.str(); + } + + void Profiler::SetName(const char* str) + { + this->_m_name = std::string(str); + } + + void Profiler::_AddProfilingUnit(time_counter start, time_counter end, + const std::string& name) + { + platform_timer * timer = this->_m_pPlatform->timer; + + struct ProfilingUnit unit; + + if (timer->cap.npu_cycles && timer->get_npu_total_cycle_diff && + timer->get_npu_active_cycle_diff) + { + unit.npuCycles = timer->get_npu_total_cycle_diff(&start, &end); + unit.activeNpuCycles = timer->get_npu_active_cycle_diff(&start, &end); + } + + if (timer->cap.cpu_cycles && timer->get_cpu_cycle_diff) { + unit.cpuCycles = timer->get_cpu_cycle_diff(&start, &end); + } + + if (timer->cap.duration_ms && timer->get_duration_ms) { + unit.time = timer->get_duration_ms(&start, &end); + } + + this->_m_series[name].emplace_back(unit); + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/application/main/UseCaseCommonUtils.cc b/source/application/main/UseCaseCommonUtils.cc new file mode 100644 index 0000000..4ea5e4d --- /dev/null +++ 
b/source/application/main/UseCaseCommonUtils.cc @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "UseCaseCommonUtils.hpp" + +#include "InputFiles.hpp" + +namespace arm { +namespace app { + + bool RunInference(hal_platform& platform, arm::app::Model& model) + { + Profiler profiler{&platform, "Inference"}; + profiler.StartProfiling(); + + bool runInf = model.RunInference(); + + profiler.StopProfiling(); + std::string profileResults = profiler.GetResultsAndReset(); + info("%s\n", profileResults.c_str()); + + return runInf; + } + + int ReadUserInputAsInt(hal_platform& platform) + { + char chInput[128]; + memset(chInput, 0, sizeof(chInput)); + + platform.data_acq->get_input(chInput, sizeof(chInput)); + return atoi(chInput); + } + + void DumpTensor(TfLiteTensor* tensor, const size_t lineBreakForNumElements) + { + char strhex[8]; + std::string strdump; + + if (!tensor) { + printf_err("invalid tensor\n"); + return; + } + + const uint32_t tensorSz = tensor->bytes; + const uint8_t* tensorData = tflite::GetTensorData(tensor); + + for (size_t i = 0; i < tensorSz; ++i) { + if (0 == i % lineBreakForNumElements) { + printf("%s\n\t", strdump.c_str()); + strdump.clear(); + } + snprintf(strhex, sizeof(strhex) - 1, + "0x%02x, ", tensorData[i]); + strdump += std::string(strhex); + } + + if (strdump.size()) { + printf("%s\n", strdump.c_str()); + } + } + + bool ListFilesHandler(ApplicationContext& ctx) + { + auto& model = ctx.Get("model"); + auto& platform = ctx.Get("platform"); + + constexpr uint32_t dataPsnTxtStartX = 20; + constexpr uint32_t dataPsnTxtStartY = 40; + + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + /* Clear the LCD */ + platform.data_psn->clear(COLOR_BLACK); + + /* Show the total number of embedded files. */ + std::string strNumFiles = std::string{"Total Number of Files: "} + + std::to_string(NUMBER_OF_FILES); + platform.data_psn->present_data_text(strNumFiles.c_str(), + strNumFiles.size(), + dataPsnTxtStartX, + dataPsnTxtStartY, + 0); + +#if NUMBER_OF_FILES > 0 + constexpr uint32_t dataPsnTxtYIncr = 16; + info("List of Files:\n"); + uint32_t yVal = dataPsnTxtStartY + dataPsnTxtYIncr; + for (uint32_t i = 0; i < NUMBER_OF_FILES; ++i, yVal += dataPsnTxtYIncr) { + + std::string currentFilename{get_filename(i)}; + platform.data_psn->present_data_text(currentFilename.c_str(), + currentFilename.size(), + dataPsnTxtStartX, yVal, 0); + + info("\t%u => %s\n", i, currentFilename.c_str()); + } +#endif /* NUMBER_OF_FILES > 0 */ + + return true; + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/application/main/include/AppContext.hpp b/source/application/main/include/AppContext.hpp new file mode 100644 index 0000000..588dfaa --- /dev/null +++ b/source/application/main/include/AppContext.hpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021 Arm Limited. 
All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef APP_CTX_HPP +#define APP_CTX_HPP + +#include +#include + +namespace arm { +namespace app { + + class IAttribute + { + public: + virtual ~IAttribute() = default; + }; + + template + class Attribute : public IAttribute + { + public: + ~Attribute() override = default; + + explicit Attribute(const T value): _m_value(value){} + + T Get() + { + return _m_value; + } + private: + T _m_value; + }; + + /* Application context class */ + class ApplicationContext { + public: + + /** + * @brief Saves given value as a named attribute in the context. + * @tparam T value type. + * @param[in] name Context attribute name. + * @param[in] object Value to save in the context. + */ + template + void Set(const std::string &name, T object) + { + this->_m_attributes[name] = new Attribute(object); + } + + /** + * @brief Gets the saved attribute from the context by the given name. + * @tparam T value type. + * @param[in] name Context attribute name. + * @return Value saved in the context. + */ + template + T Get(const std::string &name) + { + auto a = (Attribute*)_m_attributes[name]; + return a->Get(); + } + + /** + * @brief Checks if an attribute for a given name exists in the context. + * @param[in] name Attribute name. + * @return true if attribute exists, false otherwise + */ + bool Has(const std::string& name) + { + return _m_attributes.find(name) != _m_attributes.end(); + } + + ApplicationContext() = default; + + ~ApplicationContext() { + for (auto& attribute : _m_attributes) + delete attribute.second; + + this->_m_attributes.clear(); + } + private: + std::map _m_attributes; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* APP_CTX_HPP */ diff --git a/source/application/main/include/AudioUtils.hpp b/source/application/main/include/AudioUtils.hpp new file mode 100644 index 0000000..cba981d --- /dev/null +++ b/source/application/main/include/AudioUtils.hpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AUDIO_UTILS_HPP +#define AUDIO_UTILS_HPP + +#include +#include + +namespace arm { +namespace app { +namespace audio { + + template + class SlidingWindow { + public: + + /** + * @brief Creates the window slider through the given data. + * + * @param[in] data Pointer to the data to slide through. 
+ * @param[in] dataSize Size in T type elements wise. + * @param[in] windowSize Sliding window size in T type wise elements. + * @param[in] stride Stride size in T type wise elements. + */ + SlidingWindow(T *data, size_t dataSize, + size_t windowSize, size_t stride) { + m_start = data; + m_dataSize = dataSize; + m_size = windowSize; + m_stride = stride; + } + + SlidingWindow() = default; + + ~SlidingWindow() = default; + + /** + * @brief Get the next data window. + * @return Pointer to the next window, if next window is not available nullptr is returned. + */ + virtual T *Next() { + if (HasNext()) { + m_count++; + return m_start + Index() * m_stride; + } else { + return nullptr; + } + } + + /** + * @brief Checks if the next data portion is available. + * @return true if next data portion is available. + */ + virtual bool HasNext() { + return m_size + m_count * m_stride <= m_dataSize; + } + + /** + * @brief Reset the slider to the initial position. + */ + virtual void Reset() { + m_count = 0; + } + + /** + * @brief Resets the slider to the start of the new data. + * New data size MUST be the same as the old one. + * @param[in] newStart Pointer to the new data to slide through. + */ + virtual void Reset(T *newStart) { + m_start = newStart; + Reset(); + } + + /** + * @brief Gets current index of the sliding window. + * @return Current position of the sliding window in number of strides. + */ + size_t Index() { + return m_count == 0? 0: m_count - 1; + } + + /** + * @brief Gets the index from the start of the data where the next window will begin. + * While Index() returns the index of sliding window itself this function + * returns the index of the data element itself. + * @return Index from the start of the data where the next sliding window will begin. + */ + virtual uint32_t NextWindowStartIndex() { + return m_count == 0? 0: ((m_count) * m_stride); + } + + /** + * @brief Go to given sliding window index. + * @param[in] index New position of the sliding window. If index is invalid + * (greater than possible range of strides) then next call to Next() will return nullptr. + */ + void FastForward(size_t index) { + m_count = index; + } + + /** + * @brief Calculates whole number of times the window can stride through the given data. + * @return Maximum number of whole strides. + */ + size_t TotalStrides() { + if (m_size > m_dataSize) { + return 0; + } + return ((m_dataSize - m_size)/m_stride); + } + + /** + * @brief Calculates number of times the window can stride through the given data. + * May not be a whole number. + * @return Number of strides to cover all data. + */ + float FractionalTotalStrides() { + if (this->m_dataSize < this->m_size) { + return 0; + } else { + return ((this->m_dataSize - this->m_size)/ static_cast(this->m_stride)); + } + } + + protected: + T *m_start = nullptr; + size_t m_dataSize = 0; + size_t m_size = 0; + size_t m_stride = 0; + size_t m_count = 0; + }; + + /* + * Sliding window for ASR will cover the whole of the input, even if + * this means the last window is not a full window length. + */ + template + class ASRSlidingWindow : public SlidingWindow { + public: + using SlidingWindow::SlidingWindow; + + /** + * @brief Checks if the next data portion is available. + * @return true if next data portion is available. 
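+         *         Note: unlike the base SlidingWindow, this variant also
+         *         reports a final, possibly partial window so that the whole
+         *         input is covered.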
+ */ + bool HasNext() { + return this->m_count < 1 + this->FractionalTotalStrides() && (this->NextWindowStartIndex() < this->m_dataSize); + } + }; + + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* AUDIO_UTILS_HPP */ \ No newline at end of file diff --git a/source/application/main/include/ClassificationResult.hpp b/source/application/main/include/ClassificationResult.hpp new file mode 100644 index 0000000..eae28e4 --- /dev/null +++ b/source/application/main/include/ClassificationResult.hpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef CLASSIFICATION_RESULT_HPP +#define CLASSIFICATION_RESULT_HPP + +#include + +namespace arm { +namespace app { + + /** + * @brief Class representing a single classification result. + */ + class ClassificationResult { + public: + double m_normalisedVal = 0.0; + std::string m_label; + uint32_t m_labelIdx = 0; + + ClassificationResult() = default; + ~ClassificationResult() = default; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* CLASSIFICATION_RESULT_HPP */ \ No newline at end of file diff --git a/source/application/main/include/Classifier.hpp b/source/application/main/include/Classifier.hpp new file mode 100644 index 0000000..510e6f9 --- /dev/null +++ b/source/application/main/include/Classifier.hpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef CLASSIFIER_HPP +#define CLASSIFIER_HPP + +#include "ClassificationResult.hpp" +#include "TensorFlowLiteMicro.hpp" + +#include + +namespace arm { +namespace app { + + /** + * @brief Classifier - a helper class to get certain number of top + * results from the output vector from a classification NN. + **/ + class Classifier{ + public: + /** @brief Constructor. */ + Classifier() = default; + + /** + * @brief Gets the top N classification results from the + * output vector. + * @param[in] outputTensor Inference output tensor from an NN model. + * @param[out] vecResults A vector of classification results. + * populated by this function. + * @param[in] labels Labels vector to match classified classes. + * @param[in] topNCount Number of top classifications to pick. Default is 1. + * @return true if successful, false otherwise. 
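A sketch of how the classifier described above is typically driven once an output tensor and a label list are available; the topNCount of 3 is illustrative.

    #include "Classifier.hpp"

    #include <cstdio>
    #include <string>
    #include <vector>

    bool ClassifyOutput(TfLiteTensor* outputTensor, const std::vector<std::string>& labels)
    {
        arm::app::Classifier classifier;
        std::vector<arm::app::ClassificationResult> results;

        /* Populate the top 3 results from the output vector. */
        if (!classifier.GetClassificationResults(outputTensor, results, labels, 3)) {
            return false;
        }

        for (const auto& result : results) {
            printf("%u: %s (%f)\n",
                   static_cast<unsigned>(result.m_labelIdx),
                   result.m_label.c_str(),
                   result.m_normalisedVal);
        }
        return true;
    }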
+ **/ + virtual bool GetClassificationResults( + TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector & labels, uint32_t topNCount); + + private: + /** + * @brief Utility function that gets the top N classification results from the + * output vector. + * @tparam T value type + * @param[in] tensor Inference output tensor from an NN model. + * @param[out] vecResults A vector of classification results + * populated by this function. + * @param[in] topNCount Number of top classifications to pick. + * @param[in] labels Labels vector to match classified classes. + * @return true if successful, false otherwise. + **/ + template + bool _GetTopNResults(TfLiteTensor* tensor, + std::vector& vecResults, + uint32_t topNCount, + const std::vector & labels); + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* CLASSIFIER_HPP */ diff --git a/source/application/main/include/DataStructures.hpp b/source/application/main/include/DataStructures.hpp new file mode 100644 index 0000000..5cc8b5e --- /dev/null +++ b/source/application/main/include/DataStructures.hpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATA_STRUCTURES_HPP +#define DATA_STRUCTURES_HPP + +#include "hal.h" + +#include + +namespace arm { +namespace app { + + /** + * Class Array2d is a data structure that represents a two dimensional array. + * The data is allocated in contiguous memory, arranged row-wise + * and individual elements can be accessed with the () operator. + * For example a two dimensional array D of size (M, N) can be accessed: + * + * _|<------------- col size = N -------->| + * | D(r=0, c=0) D(r=0, c=1)... D(r=0, c=N) + * | D(r=1, c=0) D(r=1, c=1)... D(r=1, c=N) + * | ... + * row size = M ... + * | ... + * _ D(r=M, c=0) D(r=M, c=1)... D(r=M, c=N) + * + */ + template + class Array2d { + public: + /** + * @brief Creates the array2d with the given sizes. + * @param[in] rows Number of rows. + * @param[in] cols Number of columns. + */ + Array2d(unsigned rows, unsigned cols) + { + if (rows == 0 || cols == 0) { + printf_err("Array2d constructor has 0 size.\n"); + _m_data = nullptr; + return; + } + _m_rows = rows; + _m_cols = cols; + _m_data = new T[rows * cols]; + } + + ~Array2d() + { + delete[] _m_data; + } + + T& operator() (unsigned int row, unsigned int col) + { +#if defined(DEBUG) + if (row >= _m_rows || col >= _m_cols || _m_data == nullptr) { + printf_err("Array2d subscript out of bounds.\n"); + } +#endif /* defined(DEBUG) */ + return _m_data[_m_cols * row + col]; + } + + T operator() (unsigned int row, unsigned int col) const + { +#if defined(DEBUG) + if (row >= _m_rows || col >= _m_cols || _m_data == nullptr) { + printf_err("const Array2d subscript out of bounds.\n"); + } +#endif /* defined(DEBUG) */ + return _m_data[_m_cols * row + col]; + } + + /** + * @brief Gets rows number of the current array2d. + * @return Number of rows. 
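A brief sketch exercising the Array2d container above; the dimensions are illustrative.

    #include "DataStructures.hpp"

    float SumFeatures()
    {
        /* 2 rows x 3 columns, stored contiguously in row-major order. */
        arm::app::Array2d<float> features(2, 3);

        for (size_t row = 0; row < features.size(0); ++row) {
            for (size_t col = 0; col < features.size(1); ++col) {
                features(row, col) = static_cast<float>(row + col);
            }
        }

        /* The iterators walk all rows back to back (totalSize() == 6 here). */
        float sum = 0.0f;
        for (float value : features) {
            sum += value;
        }
        return sum;
    }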
+ */ + size_t size(size_t dim) + { + switch (dim) + { + case 0: + return _m_rows; + case 1: + return _m_cols; + default: + return 0; + } + } + + /** + * @brief Gets the array2d total size. + */ + size_t totalSize() + { + return _m_rows * _m_cols; + } + + /** + * array2d iterator. + */ + using iterator=T*; + using const_iterator=T const*; + + iterator begin() { return _m_data; } + iterator end() { return _m_data + totalSize(); } + const_iterator begin() const { return _m_data; } + const_iterator end() const { return _m_data + totalSize(); }; + + private: + size_t _m_rows; + size_t _m_cols; + T* _m_data; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* DATA_STRUCTURES_HPP */ \ No newline at end of file diff --git a/source/application/main/include/Mfcc.hpp b/source/application/main/include/Mfcc.hpp new file mode 100644 index 0000000..6379fab --- /dev/null +++ b/source/application/main/include/Mfcc.hpp @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MFCC_HPP +#define MFCC_HPP + +#include "PlatformMath.hpp" + +#include +#include +#include +#include +#include + +namespace arm { +namespace app { +namespace audio { + + /* MFCC's consolidated parameters. */ + class MfccParams { + public: + float m_samplingFreq; + uint32_t m_numFbankBins; + float m_melLoFreq; + float m_melHiFreq; + uint32_t m_numMfccFeatures; + uint32_t m_frameLen; + uint32_t m_frameLenPadded; + bool m_useHtkMethod; + + /** @brief Constructor */ + MfccParams(float samplingFreq, uint32_t numFbankBins, + float melLoFreq, float melHiFreq, + uint32_t numMfccFeats, uint32_t frameLen, + bool useHtkMethod); + + MfccParams() = delete; + + ~MfccParams() = default; + + /** @brief String representation of parameters */ + std::string Str(); + }; + + /** + * @brief Class for MFCC feature extraction. + * Based on https://github.com/ARM-software/ML-KWS-for-MCU/blob/master/Deployment/Source/MFCC/mfcc.cpp + * This class is designed to be generic and self-sufficient but + * certain calculation routines can be overridden to accommodate + * use-case specific requirements. + */ + class MFCC { + public: + /** + * @brief Constructor + * @param[in] params MFCC parameters + */ + explicit MFCC(const MfccParams& params); + + MFCC() = delete; + + ~MFCC() = default; + + /** + * @brief Extract MFCC features for one single small frame of + * audio data e.g. 640 samples. + * @param[in] audioData Vector of audio samples to calculate + * features for. + * @return Vector of extracted MFCC features. + **/ + std::vector MfccCompute(const std::vector& audioData); + + /** @brief Initialise. */ + void Init(); + + /** + * @brief Extract MFCC features and quantise for one single small + * frame of audio data e.g. 640 samples. + * @param[in] audioData Vector of audio samples to calculate + * features for. + * @param[in] quantScale Quantisation scale. + * @param[in] quantOffset Quantisation offset. 
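A sketch of constructing the MFCC extractor above and computing float features for a single frame. The parameter values are illustrative, and 16-bit audio input is assumed, as described by the internal helpers.

    #include "Mfcc.hpp"

    std::vector<float> ExtractFrameFeatures(const std::vector<int16_t>& frame)
    {
        /* samplingFreq, numFbankBins, melLoFreq, melHiFreq, numMfccFeats, frameLen, useHtkMethod */
        arm::app::audio::MfccParams params(16000.0f, 40, 20.0f, 4000.0f, 10, 640, true);

        arm::app::audio::MFCC mfcc(params);
        mfcc.Init();

        /* One vector of 10 coefficients for this 640-sample frame. */
        return mfcc.MfccCompute(frame);
    }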
+ * @return Vector of extracted quantised MFCC features. + **/ + template + std::vector MfccComputeQuant(const std::vector& audioData, + const float quantScale, + const int quantOffset) + { + this->_MfccComputePreFeature(audioData); + float minVal = std::numeric_limits::min(); + float maxVal = std::numeric_limits::max(); + + std::vector mfccOut(this->_m_params.m_numMfccFeatures); + const size_t numFbankBins = this->_m_params.m_numFbankBins; + + /* Take DCT. Uses matrix mul. */ + for (size_t i = 0, j = 0; i < mfccOut.size(); ++i, j += numFbankBins) { + float sum = 0; + for (size_t k = 0; k < numFbankBins; ++k) { + sum += this->_m_dctMatrix[j + k] * this->_m_melEnergies[k]; + } + /* Quantize to T. */ + sum = std::round((sum / quantScale) + quantOffset); + mfccOut[i] = static_cast(std::min(std::max(sum, minVal), maxVal)); + } + + return mfccOut; + } + + /* Constants */ + static constexpr float ms_logStep = /*logf(6.4)*/ 1.8562979903656 / 27.0; + static constexpr float ms_freqStep = 200.0 / 3; + static constexpr float ms_minLogHz = 1000.0; + static constexpr float ms_minLogMel = ms_minLogHz / ms_freqStep; + + protected: + /** + * @brief Project input frequency to Mel Scale. + * @param[in] freq Input frequency in floating point. + * @param[in] useHTKmethod bool to signal if HTK method is to be + * used for calculation. + * @return Mel transformed frequency in floating point. + **/ + static float MelScale(float freq, + bool useHTKMethod = true); + + /** + * @brief Inverse Mel transform - convert MEL warped frequency + * back to normal frequency. + * @param[in] freq Mel frequency in floating point. + * @param[in] useHTKmethod bool to signal if HTK method is to be + * used for calculation. + * @return Real world frequency in floating point. + **/ + static float InverseMelScale(float melFreq, + bool useHTKMethod = true); + + /** + * @brief Populates MEL energies after applying the MEL filter + * bank weights and adding them up to be placed into + * bins, according to the filter bank's first and last + * indices (pre-computed for each filter bank element + * by _CreateMelFilterBank function). + * @param[in] fftVec Vector populated with FFT magnitudes. + * @param[in] melFilterBank 2D Vector with filter bank weights. + * @param[in] filterBankFilterFirst Vector containing the first indices of filter bank + * to be used for each bin. + * @param[in] filterBankFilterLast Vector containing the last indices of filter bank + * to be used for each bin. + * @param[out] melEnergies Pre-allocated vector of MEL energies to be + * populated. + * @return true if successful, false otherwise. + */ + virtual bool ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies); + + /** + * @brief Converts the Mel energies for logarithmic scale. + * @param[in,out] melEnergies 1D vector of Mel energies. + **/ + virtual void ConvertToLogarithmicScale(std::vector& melEnergies); + + /** + * @brief Create a matrix used to calculate Discrete Cosine + * Transform. + * @param[in] inputLength Input length of the buffer on which + * DCT will be performed. + * @param[in] coefficientCount Total coefficients per input length. + * @return 1D vector with inputLength x coefficientCount elements + * populated with DCT coefficients. 
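The quantised variant above maps each coefficient as round(coeff / scale) + offset and clamps it to the target type's range. A sketch, assuming an int8 model input whose quantisation parameters are read from the input tensor:

    #include "Mfcc.hpp"
    #include "TensorFlowLiteMicro.hpp"

    std::vector<int8_t> ExtractQuantisedFeatures(arm::app::audio::MFCC& mfcc,
                                                 const std::vector<int16_t>& frame,
                                                 TfLiteTensor* inputTensor)
    {
        /* Scale and zero point of the model's input tensor. */
        const arm::app::QuantParams quant = arm::app::GetTensorQuantParams(inputTensor);

        /* Coefficients are quantised internally: round(coeff / scale) + offset, clamped to int8. */
        return mfcc.MfccComputeQuant<int8_t>(frame, quant.scale, quant.offset);
    }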
+ */ + virtual std::vector CreateDCTMatrix( + int32_t inputLength, + int32_t coefficientCount); + + /** + * @brief Given the low and high Mel values, get the normaliser + * for weights to be applied when populating the filter + * bank. + * @param[in] leftMel Low Mel frequency value. + * @param[in] rightMel High Mel frequency value. + * @param[in] useHTKMethod bool to signal if HTK method is to be + * used for calculation. + * @return Value to use for normalizing. + */ + virtual float GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + bool useHTKMethod); + + private: + MfccParams _m_params; + std::vector _m_frame; + std::vector _m_buffer; + std::vector _m_melEnergies; + std::vector _m_windowFunc; + std::vector> _m_melFilterBank; + std::vector _m_dctMatrix; + std::vector _m_filterBankFilterFirst; + std::vector _m_filterBankFilterLast; + bool _m_filterBankInitialised; + arm::app::math::FftInstance _m_fftInstance; + + /** + * @brief Initialises the filter banks and the DCT matrix. **/ + void _InitMelFilterBank(); + + /** + * @brief Signals whether the instance of MFCC has had its + * required buffers initialised. + * @return true if initialised, false otherwise. + **/ + bool _IsMelFilterBankInited(); + + /** + * @brief Create mel filter banks for MFCC calculation. + * @return 2D vector of floats. + **/ + std::vector> _CreateMelFilterBank(); + + /** + * @brief Computes and populates internal memeber buffers used + * in MFCC feature calculation + * @param[in] audioData 1D vector of 16-bit audio data. + */ + void _MfccComputePreFeature(const std::vector& audioData); + + /** @brief Computes the magnitude from an interleaved complex array. */ + void _ConvertToPowerSpectrum(); + + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* MFCC_HPP */ \ No newline at end of file diff --git a/source/application/main/include/PlatformMath.hpp b/source/application/main/include/PlatformMath.hpp new file mode 100644 index 0000000..45e6a9e --- /dev/null +++ b/source/application/main/include/PlatformMath.hpp @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef PLATFORM_MATH_HPP +#define PLATFORM_MATH_HPP + +#include "hal.h" + +/* See if ARM DSP functions can be used. */ +#if PLATFORM_HAL == PLATFORM_CORTEX_M_BAREMETAL + #if defined(__DSP_PRESENT) && (__DSP_PRESENT == 1U) + + #define ARM_DSP_AVAILABLE (1U) + #include "arm_math.h" + #define M_PI (PI) + + #endif /* defined(__DSP_PRESENT) && (__DSP_PRESENT == 1U) */ +#endif /* PLATFORM_HAL == PLATFORM_CORTEX_M_BAREMETAL */ + +#include + +namespace arm { +namespace app { +namespace math { + + struct FftInstance { +#if ARM_DSP_AVAILABLE + arm_rfft_fast_instance_f32 instance; +#endif + bool initialised = false; + }; + + /* Class to provide Math functions like FFT, mean, stddev etc. 
+ * This will allow other classes, functions to be independent of + * #if definition checks and provide a cleaner API. Also, it will + * consolidate all arm math functions used in one place and make + * them easier to test. */ + class MathUtils { + + public: + /** + * @brief Get the cosine value of the argument in floating point. + * @param[in] radians Angle in radians. + * @return Cosine value (floating point). + */ + static float CosineF32(float radians); + + /** + * @brief Get the square root of the argument in floating point. + * @param[in] input Value to compute square root of. + * @return Square root (floating point) value. + */ + static float SqrtF32(float input); + + /** + * @brief Gets the mean of a floating point array of elements. + * @param[in] ptrSrc Pointer to the first element. + * @param[in] srcLen Number of elements in the array/vector. + * @return Average value. + */ + static float MeanF32(float* ptrSrc, uint32_t srcLen); + + /** + * @brief Gets the standard deviation of a floating point array + * of elements. + * @param[in] ptrSrc Pointer to the first element. + * @param[in] srcLen Number of elements in the array/vector. + * @param[in] mean Pre-computed mean value. + * @return Standard deviation value. + */ + static float StdDevF32(float* ptrSrc, uint32_t srcLen, + float mean); + + /** + * @brief Initialises the internal FFT structures (if available + * for the platform). This function should be called + * prior to Fft32 function call if built with ARM DSP functions. + * @param[in] fftLen Requested length of the FFT. + * @param[in] fftInstance FFT instance struct to use. + * @return true if successful, false otherwise. + */ + static bool FftInitF32(const uint16_t fftLen, arm::app::math::FftInstance& fftInstance); + + /** + * @brief Computes the FFT for the input vector. + * @param[in] input Floating point vector of input elements + * @param[out] fftOutput Output buffer to be populated by computed FFTs. + * @param[in] fftInstance FFT instance struct to use. + */ + static void FftF32(std::vector& input, + std::vector& fftOutput, + arm::app::math::FftInstance& fftInstance); + + /** + * @brief Computes the natural logarithms of input floating point + * vector + * @param[in] input Floating point input vector + * @param[out] output Pre-allocated buffer to be populated with + * natural log values of each input element. + */ + static void VecLogarithmF32(std::vector & input, + std::vector & output); + + /** + * @brief Computes the dot product of two 1D floating point + * vectors. + * result = sum(srcA[0]*srcB[0] + srcA[1]*srcB[1] + ..) + * @param[in] srcPtrA Pointer to the first element of first + * array. + * @param[in] srcPtrB Pointer to the first element of second + * array. + * @param[in] srcLen Number of elements in the array/vector. + * @return Dot product. + */ + static float DotProductF32(float* srcPtrA, float* srcPtrB, + const uint32_t srcLen); + + /** + * @brief Computes the squared magnitude of floating point + * complex number array. + * @param[in] ptrSrc Pointer to the first element of input + * array. + * @param[in] srcLen Number of elements in the array/vector. + * @param[out] ptrDst Output buffer to be populated. + * @param[in] dstLen Output buffer len (for sanity check only). + * @return true if successful, false otherwise. 
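A short sketch exercising the helpers above; on the bare-metal Cortex-M build with DSP support the same calls route to the arm_math implementations. The FFT output buffer sizing here is an assumption for illustration.

    #include "PlatformMath.hpp"

    #include <vector>

    void MathExample(std::vector<float>& samples)
    {
        using arm::app::math::MathUtils;

        const float mean   = MathUtils::MeanF32(samples.data(),
                                                static_cast<uint32_t>(samples.size()));
        const float stdDev = MathUtils::StdDevF32(samples.data(),
                                                  static_cast<uint32_t>(samples.size()), mean);
        (void)stdDev;

        /* FFT: initialise the instance for the required length before computing. */
        arm::app::math::FftInstance fftInstance;
        std::vector<float> fftOutput(samples.size());
        if (MathUtils::FftInitF32(static_cast<uint16_t>(samples.size()), fftInstance)) {
            MathUtils::FftF32(samples, fftOutput, fftInstance);
        }
    }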
+ */ + static bool ComplexMagnitudeSquaredF32(float* ptrSrc, + const uint32_t srcLen, + float* ptrDst, + const uint32_t dstLen); + + }; +} /* namespace math */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* PLATFORM_MATH_HPP */ \ No newline at end of file diff --git a/source/application/main/include/Profiler.hpp b/source/application/main/include/Profiler.hpp new file mode 100644 index 0000000..b16a63b --- /dev/null +++ b/source/application/main/include/Profiler.hpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef APP_PROFILER_HPP +#define APP_PROFILER_HPP + +#include "hal.h" + +#include +#include +#include + +namespace arm { +namespace app { + + /** A single profiling unit definition. */ + struct ProfilingUnit { + uint64_t npuCycles = 0; + uint64_t activeNpuCycles = 0; + uint64_t cpuCycles = 0; + time_t time = 0; + }; + + /* A collection of profiling units. */ + using ProfilingSeries = std::vector; + + /* A map for string identifiable profiling series. */ + using ProfilingMap = std::map; + + /** + * @brief A very simple profiler example using the platform timer + * implementation. + */ + class Profiler { + public: + /** + * @brief Constructor for profiler. + * @param[in] platform Pointer to a valid, initialised hal platform. + * @param[in] name A friendly name for this profiler. + **/ + Profiler(hal_platform* platform, const char* name); + + /** Block the default constructor. */ + Profiler() = delete; + + /** Default destructor. */ + ~Profiler() = default; + + /** @brief Start profiling => get starting time-stamp. */ + bool StartProfiling(const char* name = nullptr); + + /** @brief Stop profiling => get the ending time-stamp. */ + bool StopProfiling(); + + /** @brief Stops the profiling and internally resets the + * platform timers. */ + bool StopProfilingAndReset(); + + /** @brief Reset the platform timers. */ + void Reset(); + + /** + * @brief Gets the results as string and resets the profiler. + * @returns Result string. + **/ + std::string GetResultsAndReset(); + + /** @brief Set the profiler name. */ + void SetName(const char* str); + + private: + ProfilingMap _m_series; /* Profiling series map. */ + time_counter _m_tstampSt; /* Container for a current starting timestamp. */ + time_counter _m_tstampEnd; /* Container for a current ending timestamp. */ + hal_platform * _m_pPlatform = nullptr; /* Platform pointer - to get the timer. */ + + bool _m_started = false; /* Indicates profiler has been started. */ + + std::string _m_name; /* Name given to this profiler. */ + + /** + * @brief Appends the profiling unit computed by the "start" and + * "end" timestamps to the profiling series identified by + * the name provided. + * @param[in] start Starting time-stamp. + * @param[in] end Ending time-stamp. + * @param[in] name Name for the profiling unit series to be + * appended to. 
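A sketch of wrapping a measured region with the Profiler above; it assumes an initialised hal_platform and uses printf for output.

    #include "Profiler.hpp"

    #include <cstdio>
    #include <string>

    void TimedRegion(hal_platform& platform)
    {
        arm::app::Profiler profiler(&platform, "example");

        profiler.StartProfiling("Inference");
        /* ... the work to be measured goes here ... */
        profiler.StopProfiling();

        /* Collected series are formatted as a string and the profiler is reset. */
        const std::string results = profiler.GetResultsAndReset();
        printf("%s\n", results.c_str());
    }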
+ **/ + void _AddProfilingUnit(time_counter start, time_counter end, + const std::string& name); + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* APP_PROFILER_HPP */ diff --git a/source/application/main/include/UseCaseCommonUtils.hpp b/source/application/main/include/UseCaseCommonUtils.hpp new file mode 100644 index 0000000..02200e8 --- /dev/null +++ b/source/application/main/include/UseCaseCommonUtils.hpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef USECASE_COMMON_UTILS_HPP +#define USECASE_COMMON_UTILS_HPP + +#include "hal.h" +#include "Model.hpp" +#include "AppContext.hpp" +#include "Profiler.hpp" + +/* Helper macro to convert RGB888 to RGB565 format. */ +#define RGB888_TO_RGB565(R8,G8,B8) ((((R8>>3) & 0x1F) << 11) | \ + (((G8>>2) & 0x3F) << 5) | \ + ((B8>>3) & 0x1F)) + +constexpr uint16_t COLOR_BLACK = 0; +constexpr uint16_t COLOR_GREEN = RGB888_TO_RGB565( 0, 255, 0); // 2016; +constexpr uint16_t COLOR_YELLOW = RGB888_TO_RGB565(255, 255, 0); // 65504; + +namespace arm { +namespace app { + + /** + * @brief Run inference using given model + * object. If profiling is enabled, it will log the + * statistics too. + * @param[in] platform Reference to the hal platform object. + * @param[in] model Reference to the initialised model. + * @return true if inference succeeds, false otherwise. + **/ + bool RunInference(hal_platform& platform, arm::app::Model& model); + + /** + * @brief Read input and return as an integer. + * @param[in] platform Reference to the hal platform object. + * @param[in] model Reference to the initialised model. + * @return Integer value corresponding to the user input. + **/ + int ReadUserInputAsInt(hal_platform& platform); + +#if VERIFY_TEST_OUTPUT + /** + * @brief Helper function to dump a tensor to stdout + * @param[in] tensor tensor to be dumped + * @param[in] lineBreakForNumElements number of elements + * after which line break will be added. + **/ + void DumpTensor(TfLiteTensor* tensor, + const size_t lineBreakForNumElements = 16); +#endif /* VERIFY_TEST_OUTPUT */ + + /** + * @brief List the files baked in the application. + * @param[in] ctx Reference to the application context. + * @return true or false based on event being handled. + **/ + bool ListFilesHandler(ApplicationContext& ctx); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* USECASE_COMMON_UTILS_HPP */ \ No newline at end of file diff --git a/source/application/tensorflow-lite-micro/Model.cc b/source/application/tensorflow-lite-micro/Model.cc new file mode 100644 index 0000000..0775467 --- /dev/null +++ b/source/application/tensorflow-lite-micro/Model.cc @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
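A sketch of the common use-case flow implied by the helpers above: run one inference (with profiling statistics logged when enabled) and read a menu choice from the user. The interpretation of the returned integer is illustrative.

    #include "UseCaseCommonUtils.hpp"

    bool InferOnce(hal_platform& platform, arm::app::Model& model)
    {
        if (!arm::app::RunInference(platform, model)) {
            return false;
        }

        /* A menu-driven handler might read the next action like this. */
        const int choice = arm::app::ReadUserInputAsInt(platform);
        (void)choice;
        return true;
    }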
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Model.hpp" + +#include "hal.h" + +#include + +/* Initialise the model */ +arm::app::Model::~Model() +{ + if (this->_m_pInterpreter) { + delete this->_m_pInterpreter; + } + + /** + * No clean-up function available for allocator in TensorFlow Lite Micro yet. + **/ +} + +arm::app::Model::Model() : + _m_inited (false), + _m_type(kTfLiteNoType) +{ + this->_m_pErrorReporter = &this->_m_uErrorReporter; +} + +bool arm::app::Model::Init(tflite::MicroAllocator* allocator) +{ + /* Following tf lite micro example: + * Map the model into a usable data structure. This doesn't involve any + * copying or parsing, it's a very lightweight operation. */ + const uint8_t* model_addr = ModelPointer(); + debug("loading model from @ 0x%p\n", model_addr); + this->_m_pModel = ::tflite::GetModel(model_addr); + + if (this->_m_pModel->version() != TFLITE_SCHEMA_VERSION) { + this->_m_pErrorReporter->Report( + "[ERROR] model's schema version %d is not equal " + "to supported version %d.", + this->_m_pModel->version(), TFLITE_SCHEMA_VERSION); + return false; + } + + /* Pull in only the operation implementations we need. + * This relies on a complete list of all the ops needed by this graph. + * An easier approach is to just use the AllOpsResolver, but this will + * incur some penalty in code space for op implementations that are not + * needed by this graph. + * static ::tflite::ops::micro::AllOpsResolver resolver; */ + /* NOLINTNEXTLINE(runtime-global-variables) */ + debug("loading op resolver\n"); + + this->EnlistOperations(); + + /* Create allocator instance, if it doesn't exist */ + this->_m_pAllocator = allocator; + if (!this->_m_pAllocator) { + /* Create an allocator instance */ + info("Creating allocator using tensor arena in %s\n", + ACTIVATION_BUF_SECTION_NAME); + + this->_m_pAllocator = tflite::MicroAllocator::Create( + this->GetTensorArena(), + this->GetActivationBufferSize(), + this->_m_pErrorReporter); + + if (!this->_m_pAllocator) { + printf_err("Failed to create allocator\n"); + return false; + } + debug("Created new allocator @ 0x%p\n", this->_m_pAllocator); + } else { + debug("Using existing allocator @ 0x%p\n", this->_m_pAllocator); + } + + this->_m_pInterpreter = new ::tflite::MicroInterpreter( + this->_m_pModel, this->GetOpResolver(), + this->_m_pAllocator, this->_m_pErrorReporter); + + if (!this->_m_pInterpreter) { + printf_err("Failed to allocate interpreter\n"); + return false; + } + + /* Allocate memory from the tensor_arena for the model's tensors. */ + info("Allocating tensors\n"); + TfLiteStatus allocate_status = this->_m_pInterpreter->AllocateTensors(); + + if (allocate_status != kTfLiteOk) { + this->_m_pErrorReporter->Report("[ERROR] allocateTensors() failed"); + printf_err("tensor allocation failed!\n"); + delete this->_m_pInterpreter; + return false; + } + + /* Get information about the memory area to use for the model's input. 
*/ + this->_m_input.resize(this->GetNumInputs()); + for (size_t inIndex = 0; inIndex < this->GetNumInputs(); inIndex++) + this->_m_input[inIndex] = this->_m_pInterpreter->input(inIndex); + + this->_m_output.resize(this->GetNumOutputs()); + for (size_t outIndex = 0; outIndex < this->GetNumOutputs(); outIndex++) + this->_m_output[outIndex] = this->_m_pInterpreter->output(outIndex); + + if (this->_m_input.empty() || this->_m_output.empty()) { + printf_err("failed to get tensors\n"); + return false; + } else { + this->_m_type = this->_m_input[0]->type; /* Input 0 should be the main input */ + + /* Clear the input & output tensors */ + for (size_t inIndex = 0; inIndex < this->GetNumInputs(); inIndex++) { + std::memset(this->_m_input[inIndex]->data.data, 0, this->_m_input[inIndex]->bytes); + } + for (size_t outIndex = 0; outIndex < this->GetNumOutputs(); outIndex++) { + std::memset(this->_m_output[outIndex]->data.data, 0, this->_m_output[outIndex]->bytes); + } + + this->LogInterpreterInfo(); + } + + this->_m_inited = true; + return true; +} + +tflite::MicroAllocator* arm::app::Model::GetAllocator() +{ + if (this->IsInited()) { + return this->_m_pAllocator; + } + return nullptr; +} + +void arm::app::Model::LogTensorInfo(TfLiteTensor* tensor) +{ + if (!tensor) { + printf_err("Invalid tensor\n"); + assert(tensor); + return; + } + + debug("\ttensor is assigned to 0x%p\n", tensor); + info("\ttensor type is %s\n", TfLiteTypeGetName(tensor->type)); + info("\ttensor occupies %u bytes with dimensions\n", + (uint32_t)tensor->bytes); + for (int i = 0 ; i < tensor->dims->size; ++i) { + info ("\t\t%d: %3d\n", i, tensor->dims->data[i]); + } + + TfLiteQuantization quant = tensor->quantization; + if (kTfLiteAffineQuantization == quant.type) { + auto* quantParams = (TfLiteAffineQuantization*)quant.params; + info("Quant dimension: %u\n", quantParams->quantized_dimension); + for (int i = 0; i < quantParams->scale->size; ++i) { + info("Scale[%d] = %f\n", i, quantParams->scale->data[i]); + } + for (int i = 0; i < quantParams->zero_point->size; ++i) { + info("ZeroPoint[%d] = %d\n", i, quantParams->zero_point->data[i]); + } + } +} + +void arm::app::Model::LogInterpreterInfo() +{ + if (!this->_m_pInterpreter) { + printf_err("Invalid interpreter\n"); + return; + } + + info("Model INPUT tensors: \n"); + for (auto input : this->_m_input) { + this->LogTensorInfo(input); + } + + info("Model OUTPUT tensors: \n"); + for (auto output : this->_m_output) { + this->LogTensorInfo(output); + } + + info("Activation buffer (a.k.a tensor arena) size used: %zu\n", + this->_m_pInterpreter->arena_used_bytes()); + + const uint32_t nOperators = this->_m_pInterpreter->operators_size(); + info("Number of operators: %u\n", nOperators); + + /* For each operator, display registration information */ + for (uint32_t i = 0 ; i < nOperators; ++i) { + const tflite::NodeAndRegistration nodeReg = + this->_m_pInterpreter->node_and_registration(i); + const TfLiteRegistration* reg = nodeReg.registration; + std::string opName{""}; + + if (reg) { + if (tflite::BuiltinOperator_CUSTOM == reg->builtin_code) { + opName = std::string(reg->custom_name); + } else { + opName = std::string(EnumNameBuiltinOperator( + tflite::BuiltinOperator(reg->builtin_code))); + } + } + info("\tOperator %u: %s\n", i, opName.c_str()); + } +} + +bool arm::app::Model::IsInited() const +{ + return this->_m_inited; +} + +bool arm::app::Model::IsDataSigned() const +{ + return this->GetType() == kTfLiteInt8; +} + +bool arm::app::Model::RunInference() +{ + bool inference_state = false; + 
if (this->_m_pModel && this->_m_pInterpreter) { + if (kTfLiteOk != this->_m_pInterpreter->Invoke()) { + printf_err("Invoke failed.\n"); + } else { + inference_state = true; + } + } else { + printf_err("Error: No interpreter!\n"); + } + return inference_state; +} + +TfLiteTensor* arm::app::Model::GetInputTensor(size_t index) const +{ + if (index < this->GetNumInputs()) { + return this->_m_input.at(index); + } + return nullptr; +} + +TfLiteTensor* arm::app::Model::GetOutputTensor(size_t index) const +{ + if (index < this->GetNumOutputs()) { + return this->_m_output.at(index); + } + return nullptr; +} + +size_t arm::app::Model::GetNumInputs() const +{ + if (this->_m_pModel && this->_m_pInterpreter) { + return this->_m_pInterpreter->inputs_size(); + } + return 0; +} + +size_t arm::app::Model::GetNumOutputs() const +{ + if (this->_m_pModel && this->_m_pInterpreter) { + return this->_m_pInterpreter->outputs_size(); + } + return 0; +} + + +TfLiteType arm::app::Model::GetType() const +{ + return this->_m_type; +} + +TfLiteIntArray* arm::app::Model::GetInputShape(size_t index) const +{ + if (index < this->GetNumInputs()) { + return this->_m_input.at(index)->dims; + } + return nullptr; +} + +TfLiteIntArray* arm::app::Model::GetOutputShape(size_t index) const +{ + if (index < this->GetNumOutputs()) { + return this->_m_output.at(index)->dims; + } + return nullptr; +} + +bool arm::app::Model::ShowModelInfoHandler() +{ + if (!this->IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + PrintTensorFlowVersion(); + info("Model info:\n"); + this->LogInterpreterInfo(); + +#if defined(ARM_NPU) + info("Use of Arm uNPU is enabled\n"); +#else /* ARM_NPU */ + info("Use of Arm uNPU is disabled\n"); +#endif /* ARM_NPU */ + + return true; +} +namespace arm { +namespace app { + static uint8_t _tensor_arena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; +} /* namespace app */ +} /* namespace arm */ + +size_t arm::app::Model::GetActivationBufferSize() +{ + return ACTIVATION_BUF_SZ; +} + +uint8_t *arm::app::Model::GetTensorArena() +{ + return _tensor_arena; +} \ No newline at end of file diff --git a/source/application/tensorflow-lite-micro/TensorFlowLiteMicro.cc b/source/application/tensorflow-lite-micro/TensorFlowLiteMicro.cc new file mode 100644 index 0000000..ce36a8f --- /dev/null +++ b/source/application/tensorflow-lite-micro/TensorFlowLiteMicro.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
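The accessors and inference method implemented above suggest the following caller-side sequence; a sketch, assuming a concrete Model subclass and pre-quantised int8 input features.

    #include <cstring>

    bool RunOnce(arm::app::Model& model, const int8_t* features, size_t featuresLen)
    {
        if (!model.IsInited() && !model.Init()) {
            return false;
        }

        TfLiteTensor* input = model.GetInputTensor(0);
        if (!input || input->bytes < featuresLen) {
            return false;
        }

        /* Copy the pre-processed features into the model's input tensor. */
        std::memcpy(input->data.data, features, featuresLen);

        if (!model.RunInference()) {
            return false;
        }

        TfLiteTensor* output = model.GetOutputTensor(0);
        return output != nullptr;
    }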
+ */ +#include "TensorFlowLiteMicro.hpp" + +#include "hal.h" + +void PrintTensorFlowVersion() +{ + info("uTFL version: %u.%u.%u\n", TF_MAJOR_VERSION, TF_MINOR_VERSION, + TF_PATCH_VERSION); +} + +arm::app::QuantParams arm::app::GetTensorQuantParams(TfLiteTensor* tensor) +{ + arm::app::QuantParams params; + if (kTfLiteAffineQuantization == tensor->quantization.type) { + auto* quantParams = (TfLiteAffineQuantization*) (tensor->quantization.params); + if (quantParams && 0 == quantParams->quantized_dimension) { + if (quantParams->scale->size) { + params.scale = quantParams->scale->data[0]; + } + if (quantParams->zero_point->size) { + params.offset = quantParams->zero_point->data[0]; + } + } else if (tensor->params.scale != 0.0) { + /* Legacy tensorflow quantisation parameters */ + params.scale = tensor->params.scale; + params.offset = tensor->params.zero_point; + } + } + return params; +} + diff --git a/source/application/tensorflow-lite-micro/include/BufAttributes.hpp b/source/application/tensorflow-lite-micro/include/BufAttributes.hpp new file mode 100644 index 0000000..126172b --- /dev/null +++ b/source/application/tensorflow-lite-micro/include/BufAttributes.hpp @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef BUF_ATTRIBUTES_HPP +#define BUF_ATTRIBUTES_HPP + +#ifdef __has_attribute +#define HAVE_ATTRIBUTE(x) __has_attribute(x) +#else /* __has_attribute */ +#define HAVE_ATTRIBUTE(x) 0 +#endif /* __has_attribute */ + +#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) + +/* We want all buffers/sections to be aligned to 16 byte. */ +#define ALIGNMENT_REQ aligned(16) + +/* Model data section name. */ +#define MODEL_SECTION section("nn_model") + +/* Label section name */ +#define LABEL_SECTION section("labels") + +#ifndef ACTIVATION_BUF_SZ + #warning "ACTIVATION_BUF_SZ needs to be defined. Using default value" + #define ACTIVATION_BUF_SZ 0x00200000 +#endif /* ACTIVATION_BUF_SZ */ + +#ifndef ACTIVATION_BUF_SRAM_SZ + #warning "ACTIVATION_BUF_SRAM_SZ needs to be defined. Using default value = 0" + #define ACTIVATION_BUF_SRAM_SZ 0x00000000 +#endif /* ACTIVATION_BUF_SRAM_SZ */ + +/** + * Activation buffer aka tensor arena section name + * We have to place the tensor arena in different region based on its size. + * If it fits in SRAM, we place it there, and also mark it by giving it a + * different section name. The scatter file places the ZI data in DDR and + * the uninitialised region in the SRAM. + **/ +#define ACTIVATION_BUF_SECTION_SRAM section(".bss.NoInit.activation_buf") +#define ACTIVATION_BUF_SECTION_DRAM section("activation_buf") + +#if ACTIVATION_BUF_SZ > ACTIVATION_BUF_SRAM_SZ /* Will buffer not fit in SRAM? 
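GetTensorQuantParams above yields the affine parameters for per-tensor quantisation (or the legacy fields), which plug into real = scale * (quantised - offset). A dequantisation sketch for an int8 output tensor:

    #include "TensorFlowLiteMicro.hpp"

    #include <vector>

    std::vector<float> DequantiseOutput(TfLiteTensor* outputTensor)
    {
        const arm::app::QuantParams params = arm::app::GetTensorQuantParams(outputTensor);
        const int8_t* data = outputTensor->data.int8;

        std::vector<float> dequantised(outputTensor->bytes);
        for (size_t i = 0; i < dequantised.size(); ++i) {
            /* Affine mapping: real = scale * (quantised - zero point). */
            dequantised[i] = params.scale * (static_cast<float>(data[i]) - params.offset);
        }
        return dequantised;
    }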
*/ + #define ACTIVATION_BUF_SECTION ACTIVATION_BUF_SECTION_DRAM + #define ACTIVATION_BUF_SECTION_NAME ("DDR") +#else /* ACTIVATION_BUF_SZ > 0x00200000 */ + #define ACTIVATION_BUF_SECTION ACTIVATION_BUF_SECTION_SRAM + #define ACTIVATION_BUF_SECTION_NAME ("SRAM") +#endif /* ACTIVATION_BUF_SZ > 0x00200000 */ + +/* IFM section name. */ +#define IFM_BUF_SECTION section("ifm") + +/* Form the attributes, alignment is mandatory. */ +#define MAKE_ATTRIBUTE(x) __attribute__((ALIGNMENT_REQ, x)) +#define MODEL_TFLITE_ATTRIBUTE MAKE_ATTRIBUTE(MODEL_SECTION) +#define ACTIVATION_BUF_ATTRIBUTE MAKE_ATTRIBUTE(ACTIVATION_BUF_SECTION) +#define IFM_BUF_ATTRIBUTE MAKE_ATTRIBUTE(IFM_BUF_SECTION) +#define LABELS_ATTRIBUTE MAKE_ATTRIBUTE(LABEL_SECTION) + +#else /* HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) */ + +#define MODEL_TFLITE_ATTRIBUTE +#define ACTIVATION_BUF_ATTRIBUTE +#define IFM_BUF_ATTRIBUTE +#define LABELS_ATTRIBUTE + +#endif /* HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) */ + +#endif /* BUF_ATTRIBUTES_HPP */ \ No newline at end of file diff --git a/source/application/tensorflow-lite-micro/include/Model.hpp b/source/application/tensorflow-lite-micro/include/Model.hpp new file mode 100644 index 0000000..70cf9ca --- /dev/null +++ b/source/application/tensorflow-lite-micro/include/Model.hpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MODEL_HPP +#define MODEL_HPP + +#include "TensorFlowLiteMicro.hpp" +#include "BufAttributes.hpp" + +#include + +namespace arm { +namespace app { + + /** + * @brief NN model class wrapping the underlying TensorFlow-Lite-Micro API. + */ + class Model { + public: + /** @brief Constructor. */ + Model(); + + /** @brief Destructor. */ + ~Model(); + + /** @brief Gets the pointer to the model's input tensor at given input index. */ + TfLiteTensor* GetInputTensor(size_t index) const; + + /** @brief Gets the pointer to the model's output tensor at given output index. */ + TfLiteTensor* GetOutputTensor(size_t index) const; + + /** @brief Gets the model's data type. */ + TfLiteType GetType() const; + + /** @brief Gets the pointer to the model's input shape. */ + TfLiteIntArray* GetInputShape(size_t index) const; + + /** @brief Gets the pointer to the model's output shape at given output index. */ + TfLiteIntArray* GetOutputShape(size_t index) const; + + /** @brief Gets the number of input tensors the model has. */ + size_t GetNumInputs() const; + + /** @brief Gets the number of output tensors the model has. */ + size_t GetNumOutputs() const; + + /** @brief Logs the tensor information to stdout. */ + void LogTensorInfo(TfLiteTensor* tensor); + + /** @brief Logs the interpreter information to stdout. */ + void LogInterpreterInfo(); + + /** @brief Initialise the model class object. + * @param[in] allocator Optional: a pre-initialised micro allocator pointer, + * if available. 
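A sketch of how the attribute macros above are intended to be attached to the large statically allocated buffers; the symbol names and the IFM size are illustrative, mirroring how the tensor arena is declared in Model.cc earlier in this patch.

    #include "BufAttributes.hpp"

    /* NN model flatbuffer placed in the "nn_model" section. */
    static const uint8_t nn_model_data[] MODEL_TFLITE_ATTRIBUTE = {0 /* flatbuffer bytes */};

    /* Tensor arena placed in SRAM or DDR depending on ACTIVATION_BUF_SZ. */
    static uint8_t tensor_arena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;

    /* Input feature map (e.g. a baked-in audio clip) placed in the "ifm" section. */
    static int16_t ifm_buffer[16000] IFM_BUF_ATTRIBUTE;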
If supplied, this allocator will be used + * to create the interpreter instance. + * @return true if initialisation succeeds, false otherwise. + **/ + bool Init(tflite::MicroAllocator* allocator = nullptr); + + /** + * @brief Gets the allocator pointer for this instance. + * @return Pointer to a tflite::MicroAllocator object, if + * available; nullptr otherwise. + **/ + tflite::MicroAllocator* GetAllocator(); + + /** @brief Checks if this object has been initialised. */ + bool IsInited() const; + + /** @brief Checks if the model uses signed data. */ + bool IsDataSigned() const; + + /** @brief Runs the inference (invokes the interpreter). */ + bool RunInference(); + + /** @brief Model information handler common to all models. + * @return true or false based on execution success. + **/ + bool ShowModelInfoHandler(); + + /** @brief Gets a pointer to the tensor arena. */ + uint8_t* GetTensorArena(); + + protected: + /** @brief Gets the pointer to the NN model data array. + * @return Pointer of uint8_t type. + **/ + virtual const uint8_t* ModelPointer() = 0; + + /** @brief Gets the model size. + * @return size_t, size in bytes. + **/ + virtual size_t ModelSize() = 0; + + /** + * @brief Gets the op resolver for the model instance. + * @return const reference to a tflite::MicroOpResolver object. + **/ + virtual const tflite::MicroOpResolver& GetOpResolver() = 0; + + /** + * @brief Add all the operators required for the given model. + * Implementation of this should come from the use case. + * @return true is ops are successfully added, false otherwise. + **/ + virtual bool EnlistOperations() = 0; + + /** @brief Gets the total size of tensor arena available for use. */ + size_t GetActivationBufferSize(); + + private: + tflite::MicroErrorReporter _m_uErrorReporter; /* Error reporter object. */ + tflite::ErrorReporter* _m_pErrorReporter = nullptr; /* Pointer to the error reporter. */ + const tflite::Model* _m_pModel = nullptr; /* Tflite model pointer. */ + tflite::MicroInterpreter* _m_pInterpreter = nullptr; /* Tflite interpreter. */ + tflite::MicroAllocator* _m_pAllocator = nullptr; /* Tflite micro allocator. */ + bool _m_inited = false; /* Indicates whether this object has been initialised. */ + + std::vector _m_input = {}; /* Model's input tensor pointers. */ + std::vector _m_output = {}; /* Model's output tensor pointers. */ + TfLiteType _m_type = kTfLiteNoType;/* Model's data type. */ + + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* MODEL_HPP */ diff --git a/source/application/tensorflow-lite-micro/include/TensorFlowLiteMicro.hpp b/source/application/tensorflow-lite-micro/include/TensorFlowLiteMicro.hpp new file mode 100644 index 0000000..677b4ba --- /dev/null +++ b/source/application/tensorflow-lite-micro/include/TensorFlowLiteMicro.hpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
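The pure virtual hooks above are what each use-case model must provide. A minimal sketch of a derived class follows the pattern the use-case models in this patch use; the operator set and the generated model symbols are illustrative.

    #include "Model.hpp"

    extern const uint8_t hypothetical_model_data[];     /* illustrative generated symbols */
    extern const size_t  hypothetical_model_data_len;

    class HypotheticalModel : public arm::app::Model {
    protected:
        const tflite::MicroOpResolver& GetOpResolver() override
        {
            return this->m_opResolver;
        }

        bool EnlistOperations() override
        {
            /* Register only the operators this network actually uses. */
            return kTfLiteOk == this->m_opResolver.AddConv2D()
                && kTfLiteOk == this->m_opResolver.AddFullyConnected()
                && kTfLiteOk == this->m_opResolver.AddSoftmax();
        }

        const uint8_t* ModelPointer() override { return hypothetical_model_data; }

        size_t ModelSize() override { return hypothetical_model_data_len; }

    private:
        /* Three operator slots to match the three Add* calls above. */
        tflite::MicroMutableOpResolver<3> m_opResolver;
    };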
+ */ +#ifndef TENSORFLOW_LITE_MICRO_LOCAL_HPP +#define TENSORFLOW_LITE_MICRO_LOCAL_HPP + +/* We include all our TensorFlow Lite Micro headers here */ + +/** + * TensorFlow Lite Micro sources can generate a lot of warnings from the usage + * of a single macro (TF_LITE_REMOVE_VIRTUAL_DELETE). Suppress the known ones + * here to prevent them from masking warnings that might be generated by our + * application sources. + */ +#if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050) + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wunused-parameter" + #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" + #include "tensorflow/lite/micro/micro_interpreter.h" + #include "tensorflow/lite/micro/micro_error_reporter.h" + #include "tensorflow/lite/micro/all_ops_resolver.h" + #pragma clang diagnostic pop +#elif defined(__GNUC__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunused-parameter" + #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" + #include "tensorflow/lite/micro/micro_interpreter.h" + #include "tensorflow/lite/micro/micro_error_reporter.h" + #include "tensorflow/lite/micro/all_ops_resolver.h" + #pragma GCC diagnostic pop +#else + #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" + #include "tensorflow/lite/micro/micro_interpreter.h" + #include "tensorflow/lite/micro/micro_error_reporter.h" + #include "tensorflow/lite/micro/all_ops_resolver.h" +#endif + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/micro_ops.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/version.h" + +#if defined (TESTS) + #include "tensorflow/lite/micro/test_helpers.h" +#endif /* defined (TESTS) */ + +namespace arm { +namespace app { + + struct QuantParams { + float scale = 1.0; + int offset = 0; + }; + + QuantParams GetTensorQuantParams(TfLiteTensor* tensor); + +} /* namespace app */ +} /* namespace arm */ + +/** + * @brief Prints the tensor flow version in use to stdout. + */ +void PrintTensorFlowVersion(); + +#endif /* TENSORFLOW_LITE_MICRO_LOCAL_HPP */ diff --git a/source/use_case/ad/include/AdMelSpectrogram.hpp b/source/use_case/ad/include/AdMelSpectrogram.hpp new file mode 100644 index 0000000..cf8a1d4 --- /dev/null +++ b/source/use_case/ad/include/AdMelSpectrogram.hpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef ADMELSPECTROGRAM_HPP +#define ADMELSPECTROGRAM_HPP + +#include "MelSpectrogram.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide anomaly detection specific Mel Spectrogram calculation requirements */ + class AdMelSpectrogram : public MelSpectrogram { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 64; + static constexpr uint32_t ms_defaultMelLoFreq = 0; + static constexpr uint32_t ms_defaultMelHiFreq = 8000; + static constexpr bool ms_defaultUseHtkMethod = false; + + explicit AdMelSpectrogram(const size_t frameLen) + : MelSpectrogram(MelSpecParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + frameLen, ms_defaultUseHtkMethod)) + {} + + AdMelSpectrogram() = delete; + ~AdMelSpectrogram() = default; + + protected: + + /** + * @brief Overrides base class implementation of this function. + * @param[in] fftVec Vector populated with FFT magnitudes + * @param[in] melFilterBank 2D Vector with filter bank weights + * @param[in] filterBankFilterFirst Vector containing the first indices of filter bank + * to be used for each bin. + * @param[in] filterBankFilterLast Vector containing the last indices of filter bank + * to be used for each bin. + * @param[out] melEnergies Pre-allocated vector of MEL energies to be + * populated. + * @return true if successful, false otherwise + */ + virtual bool ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) override; + + /** + * @brief Override for the base class implementation convert mel + * energies to logarithmic scale. The difference from + * default behaviour is that the power is converted to dB + * and subsequently clamped. + * @param[in/out] melEnergies - 1D vector of Mel energies + **/ + virtual void ConvertToLogarithmicScale(std::vector& melEnergies) override; + + /** + * @brief Given the low and high Mel values, get the normaliser + * for weights to be applied when populating the filter + * bank. Override for the base class implementation. + * @param[in] leftMel - low Mel frequency value + * @param[in] rightMel - high Mel frequency value + * @param[in] useHTKMethod - bool to signal if HTK method is to be + * used for calculation + * @return Return float value to be applied + * when populating the filter bank. + */ + virtual float GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod) override; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ADMELSPECTROGRAM_HPP */ diff --git a/source/use_case/ad/include/AdModel.hpp b/source/use_case/ad/include/AdModel.hpp new file mode 100644 index 0000000..2d83455 --- /dev/null +++ b/source/use_case/ad/include/AdModel.hpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
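A sketch of constructing the anomaly-detection spectrogram above; only the frame length is free, the remaining parameters are pinned to the defaults (16 kHz, 64 bins, 0 to 8 kHz, HTK method disabled). The frame length used here is illustrative.

    #include "AdMelSpectrogram.hpp"

    void PrepareSpectrogram()
    {
        constexpr size_t frameLen = 1024;  /* illustrative frame length in samples */

        arm::app::audio::AdMelSpectrogram melSpec(frameLen);
        melSpec.Init();  /* Builds the window function and mel filter bank. */
    }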
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AD_MODEL_HPP +#define AD_MODEL_HPP + +#include "Model.hpp" + +extern const int g_FrameLength; +extern const int g_FrameStride; +extern const float g_ScoreThreshold; +extern const float g_TrainingMean; + +namespace arm { +namespace app { + + class AdModel : public Model { + protected: + /** @brief Gets the reference to op resolver interface class */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance */ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted */ + static constexpr int _ms_maxOpCnt = 6; + + /* A mutable op resolver instance */ + tflite::MicroMutableOpResolver<_ms_maxOpCnt> _m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* AD_MODEL_HPP */ diff --git a/source/use_case/ad/include/AdPostProcessing.hpp b/source/use_case/ad/include/AdPostProcessing.hpp new file mode 100644 index 0000000..f3b35a1 --- /dev/null +++ b/source/use_case/ad/include/AdPostProcessing.hpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ADPOSTPROCESSING_HPP +#define ADPOSTPROCESSING_HPP + +#include "TensorFlowLiteMicro.hpp" + +#include + +namespace arm { +namespace app { + + /** @brief Dequantize TensorFlow Lite Micro tensor. + * @param[in] tensor Pointer to the TensorFlow Lite Micro tensor to be dequantized. + * @return Vector with the dequantized tensor values. + **/ + template + std::vector Dequantize(TfLiteTensor* tensor); + + /** + * @brief Calculates the softmax of vector in place. **/ + void Softmax(std::vector& inputVector); + + + /** @brief Given a wav file name return AD model output index. + * @param[in] wavFileName Audio WAV filename. + * File name should be in format __XX_.wav + * where XX is the machine ID e.g. 00, 02, 04 or 06 + * @return AD model output index as 8 bit integer. + **/ + int8_t OutputIndexFromFileName(std::string wavFileName); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ADPOSTPROCESSING_HPP */ diff --git a/source/use_case/ad/include/MelSpectrogram.hpp b/source/use_case/ad/include/MelSpectrogram.hpp new file mode 100644 index 0000000..c1dd61e --- /dev/null +++ b/source/use_case/ad/include/MelSpectrogram.hpp @@ -0,0 +1,233 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
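A sketch of the post-processing flow these helpers imply: derive the output index from the machine ID in the file name, dequantise the int8 output, and convert the scores to probabilities in place. The example file name and the negative-return check for unrecognised names are assumptions.

    #include "AdPostProcessing.hpp"

    #include <string>
    #include <vector>

    std::vector<float> PostProcess(TfLiteTensor* outputTensor, const std::string& wavFileName)
    {
        /* File name encodes the machine ID as described above (name here is illustrative). */
        const int8_t machineIdx = arm::app::OutputIndexFromFileName(wavFileName);
        if (machineIdx < 0) {
            return {};
        }

        std::vector<float> scores = arm::app::Dequantize<int8_t>(outputTensor);
        arm::app::Softmax(scores);  /* In-place conversion to a probability distribution. */
        return scores;
    }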
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MELSPECTROGRAM_HPP +#define MELSPECTROGRAM_HPP + +#include "PlatformMath.hpp" + +#include +#include +#include +#include +#include + +namespace arm { +namespace app { +namespace audio { + + /* Mel Spectrogram consolidated parameters */ + class MelSpecParams { + public: + float m_samplingFreq; + uint32_t m_numFbankBins; + float m_melLoFreq; + float m_melHiFreq; + uint32_t m_frameLen; + uint32_t m_frameLenPadded; + bool m_useHtkMethod; + + /** @brief Constructor */ + MelSpecParams(const float samplingFreq, const uint32_t numFbankBins, + const float melLoFreq, const float melHiFreq, + const uint32_t frameLen, const bool useHtkMethod); + + MelSpecParams() = delete; + ~MelSpecParams() = default; + + /** @brief String representation of parameters */ + std::string Str(); + }; + + /** + * @brief Class for Mel Spectrogram feature extraction. + * Based on https://github.com/ARM-software/ML-KWS-for-MCU/blob/master/Deployment/Source/MFCC/mfcc.cpp + * This class is designed to be generic and self-sufficient but + * certain calculation routines can be overridden to accommodate + * use-case specific requirements. + */ + class MelSpectrogram { + + public: + /** + * @brief Extract Mel Spectrogram for one single small frame of + * audio data e.g. 640 samples. + * @param[in] audioData - Vector of audio samples to calculate + * features for. + * @param[in] trainingMean - Value to subtract from the the computed mel spectrogram, default 0. + * @return Vector of extracted Mel Spectrogram features. + **/ + std::vector ComputeMelSpec(const std::vector& audioData, float trainingMean = 0); + + /** + * @brief Constructor + * @param[in] params - Mel Spectrogram parameters + */ + MelSpectrogram(const MelSpecParams& params); + + MelSpectrogram() = delete; + ~MelSpectrogram() = default; + + /** @brief Initialise */ + void Init(); + + /** + * @brief Extract Mel Spectrogram features and quantise for one single small + * frame of audio data e.g. 640 samples. + * @param[in] audioData - Vector of audio samples to calculate + * features for. + * @param[in] quantScale - quantisation scale. + * @param[in] quantOffset - quantisation offset + * @return Vector of extracted quantised Mel Spectrogram features. + **/ + template + std::vector MelSpecComputeQuant(const std::vector& audioData, + const float quantScale, + const int quantOffset, + float trainingMean = 0) + { + this->ComputeMelSpec(audioData, trainingMean); + float minVal = std::numeric_limits::min(); + float maxVal = std::numeric_limits::max(); + + std::vector melSpecOut(this->_m_params.m_numFbankBins); + const size_t numFbankBins = this->_m_params.m_numFbankBins; + + /* Quantize to T. 
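+             * Each Mel energy v is mapped to round(v / quantScale) + quantOffset and
+             * then clamped to the numeric range of T. As a purely illustrative example,
+             * with hypothetical int8 parameters quantScale = 0.1 and quantOffset = -128,
+             * an energy of 1.0 becomes round(1.0 / 0.1) + (-128) = -118.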
*/ + for (size_t k = 0; k < numFbankBins; ++k) { + auto quantizedEnergy = std::round(((this->_m_melEnergies[k]) / quantScale) + quantOffset); + melSpecOut[k] = static_cast(std::min(std::max(quantizedEnergy, minVal), maxVal)); + } + + return melSpecOut; + } + + /* Constants */ + static constexpr float ms_logStep = /*logf(6.4)*/ 1.8562979903656 / 27.0; + static constexpr float ms_freqStep = 200.0 / 3; + static constexpr float ms_minLogHz = 1000.0; + static constexpr float ms_minLogMel = ms_minLogHz / ms_freqStep; + + protected: + /** + * @brief Project input frequency to Mel Scale. + * @param[in] freq - input frequency in floating point + * @param[in] useHTKmethod - bool to signal if HTK method is to be + * used for calculation + * @return Mel transformed frequency in floating point + **/ + static float MelScale(const float freq, + const bool useHTKMethod = true); + + /** + * @brief Inverse Mel transform - convert MEL warped frequency + * back to normal frequency + * @param[in] freq - Mel frequency in floating point + * @param[in] useHTKmethod - bool to signal if HTK method is to be + * used for calculation + * @return Real world frequency in floating point + **/ + static float InverseMelScale(const float melFreq, + const bool useHTKMethod = true); + + /** + * @brief Populates MEL energies after applying the MEL filter + * bank weights and adding them up to be placed into + * bins, according to the filter bank's first and last + * indices (pre-computed for each filter bank element + * by _CreateMelFilterBank function). + * @param[in] fftVec Vector populated with FFT magnitudes + * @param[in] melFilterBank 2D Vector with filter bank weights + * @param[in] filterBankFilterFirst Vector containing the first indices of filter bank + * to be used for each bin. + * @param[in] filterBankFilterLast Vector containing the last indices of filter bank + * to be used for each bin. + * @param[out] melEnergies Pre-allocated vector of MEL energies to be + * populated. + * @return true if successful, false otherwise + */ + virtual bool ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies); + + /** + * @brief Converts the Mel energies for logarithmic scale + * @param[in/out] melEnergies - 1D vector of Mel energies + **/ + virtual void ConvertToLogarithmicScale(std::vector& melEnergies); + + /** + * @brief Given the low and high Mel values, get the normaliser + * for weights to be applied when populating the filter + * bank. + * @param[in] leftMel - low Mel frequency value + * @param[in] rightMel - high Mel frequency value + * @param[in] useHTKMethod - bool to signal if HTK method is to be + * used for calculation + * @return Return float value to be applied + * when populating the filter bank. + */ + virtual float GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod); + + private: + MelSpecParams _m_params; + std::vector _m_frame; + std::vector _m_buffer; + std::vector _m_melEnergies; + std::vector _m_windowFunc; + std::vector> _m_melFilterBank; + std::vector _m_filterBankFilterFirst; + std::vector _m_filterBankFilterLast; + bool _m_filterBankInitialised; + arm::app::math::FftInstance _m_fftInstance; + + /** + * @brief Initialises the filter banks. 
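+         *        Safe to call more than once: the work is guarded by
+         *        _IsMelFilterBankInited() and the function is invoked from both
+         *        Init() and ComputeMelSpec().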
+ **/ + void _InitMelFilterBank(); + + /** + * @brief Signals whether the instance of MelSpectrogram has had its + * required buffers initialised + * @return True if initialised, false otherwise + **/ + bool _IsMelFilterBankInited(); + + /** + * @brief Create mel filter banks for Mel Spectrogram calculation. + * @return 2D vector of floats + **/ + std::vector> _CreateMelFilterBank(); + + /** + * @brief Computes the magnitude from an interleaved complex array + **/ + void _ConvertToPowerSpectrum(); + + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + + +#endif /* MELSPECTROGRAM_HPP */ diff --git a/source/use_case/ad/include/UseCaseHandler.hpp b/source/use_case/ad/include/UseCaseHandler.hpp new file mode 100644 index 0000000..b62b36d --- /dev/null +++ b/source/use_case/ad/include/UseCaseHandler.hpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AD_EVT_HANDLER_H +#define AD_EVT_HANDLER_H +#include "AppContext.hpp" + +namespace arm { +namespace app { + /** + * @brief Handles the inference event + * @param[in] ctx pointer to the application context + * @param[in] dataIndex index to the input data to classify + * @param[in] runAll flag to request classification of all the available audio clips + * @return True or false based on execution success + **/ + bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t dataIndex, bool runAll); +} /* namespace app */ +} /* namespace arm */ +#endif /* AD_EVT_HANDLER_H */ \ No newline at end of file diff --git a/source/use_case/ad/src/AdMelSpectrogram.cc b/source/use_case/ad/src/AdMelSpectrogram.cc new file mode 100644 index 0000000..183c05c --- /dev/null +++ b/source/use_case/ad/src/AdMelSpectrogram.cc @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "AdMelSpectrogram.hpp" + +#include "PlatformMath.hpp" + +namespace arm { +namespace app { +namespace audio { + + bool AdMelSpectrogram::ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) + { + const size_t numBanks = melEnergies.size(); + + if (numBanks != filterBankFilterFirst.size() || + numBanks != filterBankFilterLast.size()) { + printf_err("unexpected filter bank lengths\n"); + return false; + } + + for (size_t bin = 0; bin < numBanks; ++bin) { + auto filterBankIter = melFilterBank[bin].begin(); + float melEnergy = 1e-10; /* Avoid log of zero at later stages. */ + const int32_t firstIndex = filterBankFilterFirst[bin]; + const int32_t lastIndex = filterBankFilterLast[bin]; + + for (int32_t i = firstIndex; i <= lastIndex; ++i) { + melEnergy += (*filterBankIter++ * fftVec[i]); + } + + melEnergies[bin] = melEnergy; + } + + return true; + } + + void AdMelSpectrogram::ConvertToLogarithmicScale( + std::vector& melEnergies) + { + /* Container for natural logarithms of mel energies */ + std::vector vecLogEnergies(melEnergies.size(), 0.f); + + /* Because we are taking natural logs, we need to multiply by log10(e). + * Also, for wav2letter model, we scale our log10 values by 10 */ + constexpr float multiplier = 10.0 * /* default scalar */ + 0.4342944819032518; /* log10f(std::exp(1.0))*/ + + /* Take log of the whole vector */ + math::MathUtils::VecLogarithmF32(melEnergies, vecLogEnergies); + + /* Scale the log values. */ + for (auto iterM = melEnergies.begin(), iterL = vecLogEnergies.begin(); + iterM != melEnergies.end(); ++iterM, ++iterL) { + + *iterM = *iterL * multiplier; + } + } + + float AdMelSpectrogram::GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod) + { + /* Slaney normalization for mel weights. */ + return (2.0f / (AdMelSpectrogram::InverseMelScale(rightMel, useHTKMethod) - + AdMelSpectrogram::InverseMelScale(leftMel, useHTKMethod))); + } + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/ad/src/AdModel.cc b/source/use_case/ad/src/AdModel.cc new file mode 100644 index 0000000..148bc98 --- /dev/null +++ b/source/use_case/ad/src/AdModel.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "AdModel.hpp" + +#include "hal.h" + +const tflite::MicroOpResolver& arm::app::AdModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +bool arm::app::AdModel::EnlistOperations() +{ + this->_m_opResolver.AddAveragePool2D(); + this->_m_opResolver.AddConv2D(); + this->_m_opResolver.AddDepthwiseConv2D(); + this->_m_opResolver.AddRelu6(); + this->_m_opResolver.AddReshape(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->_m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::AdModel::ModelPointer() +{ + return GetModelPointer(); +} +extern size_t GetModelLen(); +size_t arm::app::AdModel::ModelSize() +{ + return GetModelLen(); +} diff --git a/source/use_case/ad/src/AdPostProcessing.cc b/source/use_case/ad/src/AdPostProcessing.cc new file mode 100644 index 0000000..157784b --- /dev/null +++ b/source/use_case/ad/src/AdPostProcessing.cc @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AdPostProcessing.hpp" + +#include "hal.h" + +#include +#include +#include + +namespace arm { +namespace app { + + template + std::vector Dequantize(TfLiteTensor* tensor) { + + if (tensor == nullptr) { + printf_err("Tensor is null pointer can not dequantize.\n"); + return std::vector(); + } + T* tensorData = tflite::GetTensorData(tensor); + + uint32_t totalOutputSize = 1; + for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){ + totalOutputSize *= tensor->dims->data[inputDim]; + } + + /* For getting the floating point values, we need quantization parameters */ + QuantParams quantParams = GetTensorQuantParams(tensor); + + std::vector dequantizedOutput(totalOutputSize); + + for (size_t i = 0; i < totalOutputSize; ++i) { + dequantizedOutput[i] = quantParams.scale * (tensorData[i] - quantParams.offset); + } + + return dequantizedOutput; + } + + void Softmax(std::vector& inputVector) { + auto start = inputVector.begin(); + auto end = inputVector.end(); + + /* Fix for numerical stability and apply exp. */ + float maxValue = *std::max_element(start, end); + for (auto it = start; it!=end; ++it) { + *it = std::exp((*it) - maxValue); + } + + float sumExp = std::accumulate(start, end, 0.0f); + + for (auto it = start; it!=end; ++it) { + *it = (*it)/sumExp; + } + } + + int8_t OutputIndexFromFileName(std::string wavFileName) { + /* Filename is assumed in the form machine_id_00.wav */ + std::string delimiter = "_"; /* First character used to split the file name up. */ + size_t delimiterStart; + std::string subString; + size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. 
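+         * Worked example: for "machine_id_02.wav" the three splits on '_' leave
+         * "02.wav" in subString, the '.' split below reduces it to "02", and
+         * machine id 2 maps to output index 1.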
*/ + + for (size_t i = 0; i < machineIdxInString; ++i) { + delimiterStart = wavFileName.find(delimiter); + subString = wavFileName.substr(0, delimiterStart); + wavFileName.erase(0, delimiterStart + delimiter.length()); + } + + /* At this point substring should be 00.wav */ + delimiter = "."; /* Second character used to split the file name up. */ + delimiterStart = subString.find(delimiter); + subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString; + + auto is_number = [](const std::string& str) -> bool + { + std::string::const_iterator it = str.begin(); + while (it != str.end() && std::isdigit(*it)) ++it; + return !str.empty() && it == str.end(); + }; + + const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1; + + /* Return corresponding index in the output vector. */ + if (machineIdx == 0) { + return 0; + } else if (machineIdx == 2) { + return 1; + } else if (machineIdx == 4) { + return 2; + } else if (machineIdx == 6) { + return 3; + } else { + printf_err("%d is an invalid machine index \n", machineIdx); + return -1; + } + } + + template std::vector Dequantize(TfLiteTensor* tensor); + template std::vector Dequantize(TfLiteTensor* tensor); +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/ad/src/MainLoop.cc b/source/use_case/ad/src/MainLoop.cc new file mode 100644 index 0000000..5455b43 --- /dev/null +++ b/source/use_case/ad/src/MainLoop.cc @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* Brings in platform definitions */ +#include "InputFiles.hpp" /* For input data */ +#include "AdModel.hpp" /* Model class for running inference */ +#include "UseCaseCommonUtils.hpp" /* Utils functions */ +#include "UseCaseHandler.hpp" /* Handlers for different user options */ + +enum opcodes +{ + MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector */ + MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index */ + MENU_OPT_RUN_INF_ALL, /* Run inference on all */ + MENU_OPT_SHOW_MODEL_INFO, /* Show model info */ + MENU_OPT_LIST_AUDIO_CLIPS /* List the current baked audio signals */ +}; + +static void DisplayMenu() +{ + printf("\n\nUser input required\n"); + printf("Enter option number from:\n\n"); + printf(" %u. Classify next audio signal\n", MENU_OPT_RUN_INF_NEXT); + printf(" %u. Classify audio signal at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); + printf(" %u. Run classification on all audio signals\n", MENU_OPT_RUN_INF_ALL); + printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); + printf(" %u. List audio signals\n\n", MENU_OPT_LIST_AUDIO_CLIPS); + printf(" Choice: "); +} + + +void main_loop(hal_platform& platform) +{ + arm::app::AdModel model; /* Model wrapper object. */ + + /* Load the model. */ + if (!model.Init()) + { + printf_err("failed to initialise model\n"); + return; + } + + /* Instantiate application context. 
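+     * The values registered here (the platform, the model wrapper, the starting
+     * clip index and the g_* feature extraction/threshold parameters) are the ones
+     * ClassifyVibrationHandler later retrieves by name from this context.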
*/ + arm::app::ApplicationContext caseContext; + + caseContext.Set("platform", platform); + caseContext.Set("model", model); + caseContext.Set("clipIndex", 0); + caseContext.Set("frameLength", g_FrameLength); + caseContext.Set("frameStride", g_FrameStride); + caseContext.Set("scoreThreshold", g_ScoreThreshold); + caseContext.Set("trainingMean", g_TrainingMean); + + /* Main program loop. */ + bool executionSuccessful = true; + constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; + + /* Loop. */ + do { + int menuOption = MENU_OPT_RUN_INF_NEXT; + if (bUseMenu) { + DisplayMenu(); + menuOption = arm::app::ReadUserInputAsInt(platform); + printf("\n"); + } + switch (menuOption) { + case MENU_OPT_RUN_INF_NEXT: + executionSuccessful = ClassifyVibrationHandler( + caseContext, + caseContext.Get("clipIndex"), + false); + break; + case MENU_OPT_RUN_INF_CHOSEN: { + printf(" Enter the data index [0, %d]: ", + NUMBER_OF_FILES-1); + auto audioIndex = static_cast( + arm::app::ReadUserInputAsInt(platform)); + executionSuccessful = ClassifyVibrationHandler(caseContext, + audioIndex, + false); + break; + } + case MENU_OPT_RUN_INF_ALL: + executionSuccessful = ClassifyVibrationHandler( + caseContext, + caseContext.Get("clipIndex"), + true); + break; + case MENU_OPT_SHOW_MODEL_INFO: + executionSuccessful = model.ShowModelInfoHandler(); + break; + case MENU_OPT_LIST_AUDIO_CLIPS: + executionSuccessful = ListFilesHandler(caseContext); + break; + default: + printf("Incorrect choice, try again."); + break; + } + } while (executionSuccessful && bUseMenu); + info("Main loop terminated.\n"); +} diff --git a/source/use_case/ad/src/MelSpectrogram.cc b/source/use_case/ad/src/MelSpectrogram.cc new file mode 100644 index 0000000..86d57e6 --- /dev/null +++ b/source/use_case/ad/src/MelSpectrogram.cc @@ -0,0 +1,311 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MelSpectrogram.hpp" + +#include "PlatformMath.hpp" + +#include + +namespace arm { +namespace app { +namespace audio { + + MelSpecParams::MelSpecParams( + const float samplingFreq, + const uint32_t numFbankBins, + const float melLoFreq, + const float melHiFreq, + const uint32_t frameLen, + const bool useHtkMethod): + m_samplingFreq(samplingFreq), + m_numFbankBins(numFbankBins), + m_melLoFreq(melLoFreq), + m_melHiFreq(melHiFreq), + m_frameLen(frameLen), + + /* Smallest power of 2 >= frame length. 
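+         * For example, a frame length of 640 samples is padded to 1024, while a
+         * frame length that is already a power of two (e.g. 1024) is unchanged.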
*/ + m_frameLenPadded(pow(2, ceil((log(frameLen)/log(2))))), + m_useHtkMethod(useHtkMethod) + {} + + std::string MelSpecParams::Str() + { + char strC[1024]; + snprintf(strC, sizeof(strC) - 1, "\n \ + \n\t Sampling frequency: %f\ + \n\t Number of filter banks: %u\ + \n\t Mel frequency limit (low): %f\ + \n\t Mel frequency limit (high): %f\ + \n\t Frame length: %u\ + \n\t Padded frame length: %u\ + \n\t Using HTK for Mel scale: %s\n", + this->m_samplingFreq, this->m_numFbankBins, this->m_melLoFreq, + this->m_melHiFreq, this->m_frameLen, + this->m_frameLenPadded, this->m_useHtkMethod ? "yes" : "no"); + return std::string{strC}; + } + + MelSpectrogram::MelSpectrogram(const MelSpecParams& params): + _m_params(params), + _m_filterBankInitialised(false) + { + this->_m_buffer = std::vector( + this->_m_params.m_frameLenPadded, 0.0); + this->_m_frame = std::vector( + this->_m_params.m_frameLenPadded, 0.0); + this->_m_melEnergies = std::vector( + this->_m_params.m_numFbankBins, 0.0); + + this->_m_windowFunc = std::vector(this->_m_params.m_frameLen); + const float multiplier = 2 * M_PI / this->_m_params.m_frameLen; + + /* Create window function. */ + for (size_t i = 0; i < this->_m_params.m_frameLen; ++i) { + this->_m_windowFunc[i] = (0.5 - (0.5 * + math::MathUtils::CosineF32(static_cast(i) * multiplier))); + } + + math::MathUtils::FftInitF32(this->_m_params.m_frameLenPadded, this->_m_fftInstance); + debug("Instantiated Mel Spectrogram object: %s\n", this->_m_params.Str().c_str()); + } + + void MelSpectrogram::Init() + { + this->_InitMelFilterBank(); + } + + float MelSpectrogram::MelScale(const float freq, const bool useHTKMethod) + { + if (useHTKMethod) { + return 1127.0f * logf (1.0f + freq / 700.0f); + } else { + /* Slaney formula for mel scale. */ + float mel = freq / ms_freqStep; + + if (freq >= ms_minLogHz) { + mel = ms_minLogMel + logf(freq / ms_minLogHz) / ms_logStep; + } + return mel; + } + } + + float MelSpectrogram::InverseMelScale(const float melFreq, const bool useHTKMethod) + { + if (useHTKMethod) { + return 700.0f * (expf (melFreq / 1127.0f) - 1.0f); + } else { + /* Slaney formula for inverse mel scale. 
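+             * Mirror of MelScale(): linear (freq = ms_freqStep * mel) below the
+             * ms_minLogMel break point, exponential above it.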
*/ + float freq = ms_freqStep * melFreq; + + if (melFreq >= ms_minLogMel) { + freq = ms_minLogHz * expf(ms_logStep * (melFreq - ms_minLogMel)); + } + return freq; + } + } + + bool MelSpectrogram::ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) + { + const size_t numBanks = melEnergies.size(); + + if (numBanks != filterBankFilterFirst.size() || + numBanks != filterBankFilterLast.size()) { + printf_err("unexpected filter bank lengths\n"); + return false; + } + + for (size_t bin = 0; bin < numBanks; ++bin) { + auto filterBankIter = melFilterBank[bin].begin(); + float melEnergy = FLT_MIN; /* Avoid log of zero at later stages */ + int32_t firstIndex = filterBankFilterFirst[bin]; + int32_t lastIndex = filterBankFilterLast[bin]; + + for (int i = firstIndex; i <= lastIndex; ++i) { + float energyRep = math::MathUtils::SqrtF32(fftVec[i]); + melEnergy += (*filterBankIter++ * energyRep); + } + + melEnergies[bin] = melEnergy; + } + + return true; + } + + void MelSpectrogram::ConvertToLogarithmicScale(std::vector& melEnergies) + { + for (size_t bin = 0; bin < melEnergies.size(); ++bin) { + melEnergies[bin] = logf(melEnergies[bin]); + } + } + + void MelSpectrogram::_ConvertToPowerSpectrum() + { + const uint32_t halfDim = this->_m_params.m_frameLenPadded / 2; + + /* Handle this special case. */ + float firstEnergy = this->_m_buffer[0] * this->_m_buffer[0]; + float lastEnergy = this->_m_buffer[1] * this->_m_buffer[1]; + + math::MathUtils::ComplexMagnitudeSquaredF32( + this->_m_buffer.data(), + this->_m_buffer.size(), + this->_m_buffer.data(), + this->_m_buffer.size()/2); + + this->_m_buffer[0] = firstEnergy; + this->_m_buffer[halfDim] = lastEnergy; + } + + float MelSpectrogram::GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod) + { + UNUSED(leftMel); + UNUSED(rightMel); + UNUSED(useHTKMethod); + + /* By default, no normalisation => return 1 */ + return 1.f; + } + + void MelSpectrogram::_InitMelFilterBank() + { + if (!this->_IsMelFilterBankInited()) { + this->_m_melFilterBank = this->_CreateMelFilterBank(); + this->_m_filterBankInitialised = true; + } + } + + bool MelSpectrogram::_IsMelFilterBankInited() + { + return this->_m_filterBankInitialised; + } + + std::vector MelSpectrogram::ComputeMelSpec(const std::vector& audioData, float trainingMean) + { + this->_InitMelFilterBank(); + + /* TensorFlow way of normalizing .wav data to (-1, 1). */ + constexpr float normaliser = 1.0/(1<<15); + for (size_t i = 0; i < this->_m_params.m_frameLen; ++i) { + this->_m_frame[i] = static_cast(audioData[i]) * normaliser; + } + + /* Apply window function to input frame. */ + for(size_t i = 0; i < this->_m_params.m_frameLen; ++i) { + this->_m_frame[i] *= this->_m_windowFunc[i]; + } + + /* Set remaining frame values to 0. */ + std::fill(this->_m_frame.begin() + this->_m_params.m_frameLen,this->_m_frame.end(), 0); + + /* Compute FFT. */ + math::MathUtils::FftF32(this->_m_frame, this->_m_buffer, this->_m_fftInstance); + + /* Convert to power spectrum. */ + this->_ConvertToPowerSpectrum(); + + /* Apply mel filterbanks. 
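+         * The base-class implementation weights the square root of each power
+         * spectrum bin between a bank's first and last index; AdMelSpectrogram
+         * overrides this to weight the power values directly.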
*/ + if (!this->ApplyMelFilterBank(this->_m_buffer, + this->_m_melFilterBank, + this->_m_filterBankFilterFirst, + this->_m_filterBankFilterLast, + this->_m_melEnergies)) { + printf_err("Failed to apply MEL filter banks\n"); + } + + /* Convert to logarithmic scale */ + this->ConvertToLogarithmicScale(this->_m_melEnergies); + + /* Perform mean subtraction. */ + for (auto& energy:this->_m_melEnergies) { + energy -= trainingMean; + } + + return this->_m_melEnergies; + } + + std::vector> MelSpectrogram::_CreateMelFilterBank() + { + size_t numFftBins = this->_m_params.m_frameLenPadded / 2; + float fftBinWidth = static_cast(this->_m_params.m_samplingFreq) / this->_m_params.m_frameLenPadded; + + float melLowFreq = MelSpectrogram::MelScale(this->_m_params.m_melLoFreq, + this->_m_params.m_useHtkMethod); + float melHighFreq = MelSpectrogram::MelScale(this->_m_params.m_melHiFreq, + this->_m_params.m_useHtkMethod); + float melFreqDelta = (melHighFreq - melLowFreq) / (this->_m_params.m_numFbankBins + 1); + + std::vector thisBin = std::vector(numFftBins); + std::vector> melFilterBank( + this->_m_params.m_numFbankBins); + this->_m_filterBankFilterFirst = + std::vector(this->_m_params.m_numFbankBins); + this->_m_filterBankFilterLast = + std::vector(this->_m_params.m_numFbankBins); + + for (size_t bin = 0; bin < this->_m_params.m_numFbankBins; bin++) { + float leftMel = melLowFreq + bin * melFreqDelta; + float centerMel = melLowFreq + (bin + 1) * melFreqDelta; + float rightMel = melLowFreq + (bin + 2) * melFreqDelta; + + int32_t firstIndex = -1; + int32_t lastIndex = -1; + const float normaliser = this->GetMelFilterBankNormaliser(leftMel, rightMel, this->_m_params.m_useHtkMethod); + + for (size_t i = 0; i < numFftBins; ++i) { + float freq = (fftBinWidth * i); /* Center freq of this fft bin. */ + float mel = MelSpectrogram::MelScale(freq, this->_m_params.m_useHtkMethod); + thisBin[i] = 0.0; + + if (mel > leftMel && mel < rightMel) { + float weight; + if (mel <= centerMel) { + weight = (mel - leftMel) / (centerMel - leftMel); + } else { + weight = (rightMel - mel) / (rightMel - centerMel); + } + + thisBin[i] = weight * normaliser; + if (firstIndex == -1) { + firstIndex = i; + } + lastIndex = i; + } + } + + this->_m_filterBankFilterFirst[bin] = firstIndex; + this->_m_filterBankFilterLast[bin] = lastIndex; + + /* Copy the part we care about. */ + for (int32_t i = firstIndex; i <= lastIndex; ++i) { + melFilterBank[bin].push_back(thisBin[i]); + } + } + + return melFilterBank; + } + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/ad/src/UseCaseHandler.cc b/source/use_case/ad/src/UseCaseHandler.cc new file mode 100644 index 0000000..c18a0a4 --- /dev/null +++ b/source/use_case/ad/src/UseCaseHandler.cc @@ -0,0 +1,422 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "UseCaseHandler.hpp" + +#include "AdModel.hpp" +#include "InputFiles.hpp" +#include "Classifier.hpp" +#include "hal.h" +#include "AdMelSpectrogram.hpp" +#include "AudioUtils.hpp" +#include "UseCaseCommonUtils.hpp" +#include "AdPostProcessing.hpp" + +namespace arm { +namespace app { + + /** + * @brief Helper function to increment current audio clip index + * @param[in/out] ctx pointer to the application context object + **/ + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx); + + /** + * @brief Helper function to set the audio clip index + * @param[in/out] ctx pointer to the application context object + * @param[in] idx value to be set + * @return true if index is set, false otherwise + **/ + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); + + /** + * @brief Presents inference results using the data presentation + * object. + * @param[in] platform reference to the hal platform object + * @param[in] result average sum of classification results + * @param[in] threhsold if larger than this value we have an anomaly + * @return true if successful, false otherwise + **/ + static bool _PresentInferenceResult(hal_platform& platform, float result, float threshold); + + /** + * @brief Returns a function to perform feature calculation and populates input tensor data with + * MelSpe data. + * + * Input tensor data type check is performed to choose correct MFCC feature data type. + * If tensor has an integer data type then original features are quantised. + * + * Warning: mfcc calculator provided as input must have the same life scope as returned function. + * + * @param[in] mfcc MFCC feature calculator. + * @param[in/out] inputTensor Input tensor pointer to store calculated features. + * @param[i] cacheSize Size of the feture vectors cache (number of feature vectors). + * @return function function to be called providing audio sample and sliding window index. + */ + static std::function&, int, bool, size_t, size_t)> + GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, + TfLiteTensor* inputTensor, + size_t cacheSize, + float trainingMean); + + /* Vibration classification handler */ + bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) + { + auto& platform = ctx.Get("platform"); + + constexpr uint32_t dataPsnTxtInfStartX = 20; + constexpr uint32_t dataPsnTxtInfStartY = 40; + + platform.data_psn->clear(COLOR_BLACK); + + auto& model = ctx.Get("model"); + + /* If the request has a valid size, set the audio index */ + if (clipIndex < NUMBER_OF_FILES) { + if (!_SetAppCtxClipIdx(ctx, clipIndex)) { + return false; + } + } + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + const auto frameLength = ctx.Get("frameLength"); + const auto frameStride = ctx.Get("frameStride"); + const auto scoreThreshold = ctx.Get("scoreThreshold"); + const float trainingMean = ctx.Get("trainingMean"); + auto startClipIdx = ctx.Get("clipIndex"); + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); + + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } + + TfLiteIntArray* inputShape = model.GetInputShape(0); + const uint32_t kNumRows = inputShape->data[1]; + const uint32_t kNumCols = inputShape->data[2]; + + audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength); + melSpec.Init(); + + /* Deduce the data length required for 1 inference from the network parameters. 
*/ + const uint8_t inputResizeScale = 2; + const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength; + + /* We are choosing to move by 20 frames across the audio for each inference. */ + const uint8_t nMelSpecVectorsInAudioStride = 20; + + auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride; + + do { + auto currentIndex = ctx.Get("clipIndex"); + + /* Get the output index to look at based on id in the filename. */ + int8_t machineOutputIndex = OutputIndexFromFileName(get_filename(currentIndex)); + if (machineOutputIndex == -1) { + return false; + } + + /* Creating a Mel Spectrogram sliding window for the data required for 1 inference. + * "resizing" done here by multiplying stride by resize scale. */ + auto audioMelSpecWindowSlider = audio::SlidingWindow( + get_audio_array(currentIndex), + audioDataWindowSize, frameLength, + frameStride * inputResizeScale); + + /* Creating a sliding window through the whole audio clip. */ + auto audioDataSlider = audio::SlidingWindow( + get_audio_array(currentIndex), + get_audio_array_size(currentIndex), + audioDataWindowSize, audioDataStride); + + /* Calculate number of the feature vectors in the window overlap region taking into account resizing. + * These feature vectors will be reused.*/ + auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale); + + /* Construct feature calculation function. */ + auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor, + numberOfReusedFeatureVectors, trainingMean); + if (!melSpecFeatureCalc){ + return false; + } + + /* Result is an averaged sum over inferences. */ + float result = 0; + + /* Display message on the LCD - inference running. */ + std::string str_inf{"Running inference... "}; + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + info("Running inference on audio clip %u => %s\n", currentIndex, get_filename(currentIndex)); + + /* Start sliding through audio clip. */ + while (audioDataSlider.HasNext()) { + const int16_t *inferenceWindow = audioDataSlider.Next(); + + /* We moved to the next window - set the features sliding to the new address. */ + audioMelSpecWindowSlider.Reset(inferenceWindow); + + /* The first window does not have cache ready. */ + bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; + + /* Start calculating features inside one audio sliding window. */ + while (audioMelSpecWindowSlider.HasNext()) { + const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next(); + std::vector melSpecAudioData = std::vector(melSpecWindow, + melSpecWindow + frameLength); + + /* Compute features for this window and write them to input tensor. */ + melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(), + useCache, nMelSpecVectorsInAudioStride, inputResizeScale); + } + + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); + + /* Run inference over this audio clip sliding window */ + arm::app::RunInference(platform, model); + + /* Use the negative softmax score of the corresponding index as the outlier score */ + std::vector dequantOutput = Dequantize(outputTensor); + Softmax(dequantOutput); + result += -dequantOutput[machineOutputIndex]; + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(outputTensor); +#endif /* VERIFY_TEST_OUTPUT */ + } /* while (audioDataSlider.HasNext()) */ + + /* Use average over whole clip as final score. 
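+             * Each inference contributed the negated softmax probability of the
+             * expected machine id, so a higher (less negative) average means a less
+             * confident "normal" prediction; it is reported as an anomaly further
+             * down when it exceeds the configured score threshold (default -0.8).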
*/ + result /= (audioDataSlider.TotalStrides() + 1); + + /* Erase. */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + ctx.Set("result", result); + if (!_PresentInferenceResult(platform, result, scoreThreshold)) { + return false; + } + + _IncrementAppCtxClipIdx(ctx); + + } while (runAll && ctx.Get("clipIndex") != startClipIdx); + + return true; + } + + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx) + { + auto curAudioIdx = ctx.Get("clipIndex"); + + if (curAudioIdx + 1 >= NUMBER_OF_FILES) { + ctx.Set("clipIndex", 0); + return; + } + ++curAudioIdx; + ctx.Set("clipIndex", curAudioIdx); + } + + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx) + { + if (idx >= NUMBER_OF_FILES) { + printf_err("Invalid idx %u (expected less than %u)\n", + idx, NUMBER_OF_FILES); + return false; + } + ctx.Set("clipIndex", idx); + return true; + } + + static bool _PresentInferenceResult(hal_platform& platform, float result, float threshold) + { + constexpr uint32_t dataPsnTxtStartX1 = 20; + constexpr uint32_t dataPsnTxtStartY1 = 30; + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment */ + + platform.data_psn->set_text_color(COLOR_GREEN); + + /* Display each result */ + uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; + + std::string resultStr = std::string{"Average anomaly score is: "} + std::to_string(result) + + std::string("\n") + std::string("Anomaly threshold is: ") + std::to_string(threshold) + + std::string("\n"); + + if (result > threshold) { + resultStr += std::string("Anomaly detected!"); + } else { + resultStr += std::string("Everything fine, no anomaly detected!"); + } + + platform.data_psn->present_data_text( + resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX1, rowIdx1, 0); + + info("%s\n", resultStr.c_str()); + + return true; + } + + /** + * @brief Generic feature calculator factory. + * + * Returns lambda function to compute features using features cache. + * Real features math is done by a lambda function provided as a parameter. + * Features are written to input tensor memory. + * + * @tparam T feature vector type. + * @param inputTensor model input tensor pointer. + * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. + * @param compute features calculator function. + * @return lambda function to compute features. + */ + template + std::function&, size_t, bool, size_t, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function (std::vector& )> compute) + { + /* Feature cache to be captured by lambda function*/ + static std::vector> featureCache = std::vector>(cacheSize); + + return [=](std::vector& audioDataWindow, + size_t index, + bool useCache, + size_t featuresOverlapIndex, + size_t resizeScale) + { + T *tensorData = tflite::GetTensorData(inputTensor); + std::vector features; + + /* Reuse features from cache if cache is ready and sliding windows overlap. + * Overlap is in the beginning of sliding window with a size of a feature cache. */ + if (useCache && index < featureCache.size()) { + features = std::move(featureCache[index]); + } else { + features = std::move(compute(audioDataWindow)); + } + auto size = features.size() / resizeScale; + auto sizeBytes = sizeof(T); + + /* Input should be transposed and "resized" by skipping elements. 
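+              * Only every resizeScale-th element of the feature vector is kept, and
+              * element outIndex is written at offset (outIndex * size) + index, i.e.
+              * the tensor is filled feature-major with the sliding window index
+              * selecting the column.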
*/ + for (size_t outIndex = 0; outIndex < size; outIndex++) { + std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes); + } + + /* Start renewing cache as soon iteration goes out of the windows overlap. */ + if (index >= featuresOverlapIndex / resizeScale) { + featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features); + } + }; + } + + template std::function&, size_t , bool, size_t, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector&)> compute); + + template std::function&, size_t , bool, size_t, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector&)> compute); + + template std::function&, size_t , bool, size_t, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector&)> compute); + + template std::function&, size_t, bool, size_t, size_t)> + _FeatureCalc(TfLiteTensor *inputTensor, + size_t cacheSize, + std::function(std::vector&)> compute); + + + static std::function&, int, bool, size_t, size_t)> + GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean) + { + std::function&, size_t, bool, size_t, size_t)> melSpecFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + + auto *quantParams = (TfLiteAffineQuantization *) quant.params; + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + melSpecFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &melSpec](std::vector& audioDataWindow) { + return melSpec.MelSpecComputeQuant(audioDataWindow, + quantScale, + quantOffset, + trainingMean); + } + ); + break; + } + case kTfLiteUInt8: { + melSpecFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &melSpec](std::vector& audioDataWindow) { + return melSpec.MelSpecComputeQuant(audioDataWindow, + quantScale, + quantOffset, + trainingMean); + } + ); + break; + } + case kTfLiteInt16: { + melSpecFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &melSpec](std::vector& audioDataWindow) { + return melSpec.MelSpecComputeQuant(audioDataWindow, + quantScale, + quantOffset, + trainingMean); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + + + } else { + melSpecFeatureCalc = melSpecFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &melSpec](std::vector& audioDataWindow) { + return melSpec.ComputeMelSpec(audioDataWindow, + trainingMean); + }); + } + return melSpecFeatureCalc; + } + +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/ad/usecase.cmake b/source/use_case/ad/usecase.cmake new file mode 100644 index 0000000..46e4101 --- /dev/null +++ b/source/use_case/ad/usecase.cmake @@ -0,0 +1,111 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#---------------------------------------------------------------------------- + +# If the path to a directory or source file has been defined, +# get the type here (FILEPATH or PATH): +if (DEFINED ${use_case}_FILE_PATH) + get_path_type(${${use_case}_FILE_PATH} PATH_TYPE) + + # Set the default type if path is not a dir or file path (or undefined) + if (NOT ${PATH_TYPE} STREQUAL PATH AND NOT ${PATH_TYPE} STREQUAL FILEPATH) + message(FATAL_ERROR "Invalid ${use_case}_FILE_PATH. It should be a dir or file path.") + endif() +else() + # Default is a directory path + set(PATH_TYPE PATH) +endif() + +message(STATUS "${use_case}_FILE_PATH is of type: ${PATH_TYPE}") + +USER_OPTION(${use_case}_FILE_PATH "Directory with custom WAV input files, or path to a single input WAV file, to use in the evaluation application." + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/samples/ + ${PATH_TYPE}) + +USER_OPTION(${use_case}_AUDIO_RATE "Specify the target sampling rate. Default is 16000." + 16000 + STRING) + +USER_OPTION(${use_case}_AUDIO_MONO "Specify if the audio needs to be converted to mono. Default is ON." + ON + BOOL) + +USER_OPTION(${use_case}_AUDIO_OFFSET "Specify the offset to start reading after this time (in seconds). Default is 0." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_DURATION "Specify the audio duration to load (in seconds). If set to 0 the entire audio will be processed." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_RES_TYPE "Specify re-sampling algorithm to use. By default is 'kaiser_best'." + kaiser_best + STRING) + +USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples to use. By default is amount needed to do one inference, + if the audio is shorter then it will be automatically padded." + 33280 + STRING) + +USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD "Specify the score threshold for a result to be deemed anomalous." + -0.8 + STRING) + +generate_audio_code(${${use_case}_FILE_PATH} ${SRC_GEN_DIR} ${INC_GEN_DIR} + ${${use_case}_AUDIO_RATE} + ${${use_case}_AUDIO_MONO} + ${${use_case}_AUDIO_OFFSET} + ${${use_case}_AUDIO_DURATION} + ${${use_case}_AUDIO_RES_TYPE} + ${${use_case}_AUDIO_MIN_SAMPLES}) + +USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen model" + 0x00200000 + STRING) + +# If there is no tflite file pointed to +if (NOT DEFINED ${use_case}_MODEL_TFLITE_PATH) + + set(MODEL_RESOURCES_DIR ${DOWNLOAD_DEP_DIR}/${use_case}) + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR}) + set(MODEL_FILENAME ad_med_nov11_int8.tflite) + set(DEFAULT_MODEL_PATH ${MODEL_RESOURCES_DIR}/${MODEL_FILENAME}) + + # TODO: Download the model here for this use case when available on Model Zoo. + # For now we write a place holder file. 
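+    # A real model can be supplied on the command line instead via the
+    # ${use_case}_MODEL_TFLITE_PATH option defined below, e.g. (assuming the use
+    # case name resolves to "ad", with a hypothetical model path):
+    #   cmake .. -Dad_MODEL_TFLITE_PATH=/path/to/ad_model_int8.tflite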
+ file(WRITE ${DEFAULT_MODEL_PATH} "Placeholder") +else() + set(DEFAULT_MODEL_PATH "N/A") +endif() + +set(EXTRA_MODEL_CODE + "/* Model parameters for ${use_case} */" + "extern const int g_FrameLength = 1024" + "extern const int g_FrameStride = 512" + "extern const float g_ScoreThreshold = ${${use_case}_MODEL_SCORE_THRESHOLD}" + "extern const float g_TrainingMean = -30" + ) + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH "NN models file to be used in the evaluation application. Model files must be in tflite format." + ${DEFAULT_MODEL_PATH} + FILEPATH) + +# Generate model file +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH} + DESTINATION ${SRC_GEN_DIR} + EXPRESSIONS ${EXTRA_MODEL_CODE} +) diff --git a/source/use_case/asr/include/AsrClassifier.hpp b/source/use_case/asr/include/AsrClassifier.hpp new file mode 100644 index 0000000..1a63814 --- /dev/null +++ b/source/use_case/asr/include/AsrClassifier.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_CLASSIFIER_HPP +#define ASR_CLASSIFIER_HPP + +#include "Classifier.hpp" + +namespace arm { +namespace app { + + class AsrClassifier : public Classifier { + public: + /** + * @brief Gets the top N classification results from the + * output vector. + * @param[in] outputTensor Inference output tensor from an NN model. + * @param[out] vecResults A vector of classification results + * populated by this function. + * @param[in] labels Labels vector to match classified classes + * @param[in] topNCount Number of top classifications to pick. + * @return true if successful, false otherwise. + **/ + bool GetClassificationResults( + TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector & labels, uint32_t topNCount) override; + + private: + /** + * @brief Utility function that gets the top 1 classification results from the + * output tensor (vector of vector). + * @param[in] tensor Inference output tensor from an NN model. + * @param[out] vecResults Vector of classification results populated by this function. + * @param[in] labels Labels vector to match classified classes. + * @param[in] scale Quantization scale. + * @param[in] zeroPoint Quantization zero point. + * @return true if successful, false otherwise. + **/ + template + bool _GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint); + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_CLASSIFIER_HPP */ \ No newline at end of file diff --git a/source/use_case/asr/include/AsrResult.hpp b/source/use_case/asr/include/AsrResult.hpp new file mode 100644 index 0000000..b12ed7d --- /dev/null +++ b/source/use_case/asr/include/AsrResult.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_RESULT_HPP +#define ASR_RESULT_HPP + +#include "ClassificationResult.hpp" + +#include + +namespace arm { +namespace app { +namespace asr { + + using ResultVec = std::vector < arm::app::ClassificationResult >; + + /* Structure for holding ASR result. */ + class AsrResult { + + public: + ResultVec m_resultVec; /* Container for "thresholded" classification results. */ + float m_timeStamp; /* Audio timestamp for this result. */ + uint32_t m_inferenceNumber; /* Corresponding inference number. */ + float m_threshold; /* Threshold value for `m_resultVec.` */ + + AsrResult() = delete; + AsrResult(ResultVec& resultVec, + const float timestamp, + const uint32_t inferenceIdx, + const float scoreThreshold) { + + this->m_threshold = scoreThreshold; + this->m_timeStamp = timestamp; + this->m_inferenceNumber = inferenceIdx; + + this->m_resultVec = ResultVec(); + for (auto& i : resultVec) { + if (i.m_normalisedVal >= this->m_threshold) { + this->m_resultVec.emplace_back(i); + } + } + } + ~AsrResult() = default; + }; + +} /* namespace asr */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_RESULT_HPP */ \ No newline at end of file diff --git a/source/use_case/asr/include/OutputDecode.hpp b/source/use_case/asr/include/OutputDecode.hpp new file mode 100644 index 0000000..6095531 --- /dev/null +++ b/source/use_case/asr/include/OutputDecode.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_OUTPUT_DECODE_HPP +#define ASR_OUTPUT_DECODE_HPP + +#include "AsrClassifier.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + /** + * @brief Gets the top N classification results from the + * output vector. + * @param[in] tensor Label output from classifier. + * @return true if successful, false otherwise. + **/ + std::string DecodeOutput(const std::vector& vecResults); + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_OUTPUT_DECODE_HPP */ \ No newline at end of file diff --git a/source/use_case/asr/include/UseCaseHandler.hpp b/source/use_case/asr/include/UseCaseHandler.hpp new file mode 100644 index 0000000..75052c7 --- /dev/null +++ b/source/use_case/asr/include/UseCaseHandler.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_EVT_HANDLER_HPP +#define ASR_EVT_HANDLER_HPP + +#include "AppContext.hpp" + +namespace arm { +namespace app { + + /** + * @brief Handles the inference event. + * @param[in] ctx Pointer to the application context. + * @param[in] clipIndex Index to the audio clip to classify. + * @param[in] runAll Flag to request classification of all the available audio clips. + * @return true or false based on execution success. + **/ + bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_EVT_HANDLER_HPP */ diff --git a/source/use_case/asr/include/Wav2LetterMfcc.hpp b/source/use_case/asr/include/Wav2LetterMfcc.hpp new file mode 100644 index 0000000..3cb43b9 --- /dev/null +++ b/source/use_case/asr/include/Wav2LetterMfcc.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_MFCC_HPP +#define ASR_WAV2LETTER_MFCC_HPP + +#include "Mfcc.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide Wav2Letter specific MFCC calculation requirements. */ + class Wav2LetterMFCC : public MFCC { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 128; + static constexpr uint32_t ms_defaultMelLoFreq = 0; + static constexpr uint32_t ms_defaultMelHiFreq = 8000; + static constexpr bool ms_defaultUseHtkMethod = false; + + explicit Wav2LetterMFCC(const size_t numFeats, const size_t frameLen) + : MFCC(MfccParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + numFeats, frameLen, ms_defaultUseHtkMethod)) + {} + + Wav2LetterMFCC() = delete; + ~Wav2LetterMFCC() = default; + + protected: + + /** + * @brief Overrides base class implementation of this function. + * @param[in] fftVec Vector populated with FFT magnitudes + * @param[in] melFilterBank 2D Vector with filter bank weights + * @param[in] filterBankFilterFirst Vector containing the first indices of filter bank + * to be used for each bin. + * @param[in] filterBankFilterLast Vector containing the last indices of filter bank + * to be used for each bin. + * @param[out] melEnergies Pre-allocated vector of MEL energies to be + * populated. 
+ * @return true if successful, false otherwise + */ + bool ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) override; + + /** + * @brief Override for the base class implementation convert mel + * energies to logarithmic scale. The difference from + * default behaviour is that the power is converted to dB + * and subsequently clamped. + * @param[in,out] melEnergies 1D vector of Mel energies + **/ + void ConvertToLogarithmicScale(std::vector& melEnergies) override; + + /** + * @brief Create a matrix used to calculate Discrete Cosine + * Transform. Override for the base class' default + * implementation as the first and last elements + * use a different normaliser. + * @param[in] inputLength input length of the buffer on which + * DCT will be performed + * @param[in] coefficientCount Total coefficients per input length. + * @return 1D vector with inputLength x coefficientCount elements + * populated with DCT coefficients. + */ + std::vector CreateDCTMatrix(int32_t inputLength, + int32_t coefficientCount) override; + + /** + * @brief Given the low and high Mel values, get the normaliser + * for weights to be applied when populating the filter + * bank. Override for the base class implementation. + * @param[in] leftMel Low Mel frequency value. + * @param[in] rightMel High Mel frequency value. + * @param[in] useHTKMethod bool to signal if HTK method is to be + * used for calculation. + * @return Value to use for normalising. + */ + float GetMelFilterBankNormaliser(const float& leftMel, + const float& rightMel, + bool useHTKMethod) override; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_MFCC_HPP */ \ No newline at end of file diff --git a/source/use_case/asr/include/Wav2LetterModel.hpp b/source/use_case/asr/include/Wav2LetterModel.hpp new file mode 100644 index 0000000..b801e10 --- /dev/null +++ b/source/use_case/asr/include/Wav2LetterModel.hpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved.rved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_MODEL_HPP +#define ASR_WAV2LETTER_MODEL_HPP + +#include "Model.hpp" + +extern const int g_FrameLength; +extern const int g_FrameStride; +extern const float g_ScoreThreshold; +extern const int g_ctxLen; + +namespace arm { +namespace app { + + class Wav2LetterModel : public Model { + + public: + /* Indices for the expected model - based on input and output tensor shapes */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_outputRowsIdx = 2; + static constexpr uint32_t ms_outputColsIdx = 3; + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. 
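/* A short sketch of how the shape-index constants above are used to read the
 * model's tensor dimensions ('model' is an assumed, already-initialised
 * Wav2LetterModel); the helper functions in MainLoop.cc below follow the same
 * pattern. */
TfLiteTensor* inputTensor  = model.GetInputTensor(0);
TfLiteTensor* outputTensor = model.GetOutputTensor(0);
const int numFeatVectors = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
const int numInputCols   = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx];
const int numOutputRows  = outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx];
const int numClasses     = outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx];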
*/ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int _ms_maxOpCnt = 5; + + /* A mutable op resolver instance. */ + tflite::MicroMutableOpResolver<_ms_maxOpCnt> _m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_MODEL_HPP */ diff --git a/source/use_case/asr/include/Wav2LetterPostprocess.hpp b/source/use_case/asr/include/Wav2LetterPostprocess.hpp new file mode 100644 index 0000000..69567a3 --- /dev/null +++ b/source/use_case/asr/include/Wav2LetterPostprocess.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_POSTPROCESS_HPP +#define ASR_WAV2LETTER_POSTPROCESS_HPP + +#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers. */ +#include "hal.h" /* stdout facility. */ + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + /** + * @brief Helper class to manage tensor post-processing for "wav2letter" + * output. + */ + class Postprocess { + public: + /** + * @brief Constructor + * @param[in] contextLen Left and right context length for + * output tensor. + * @param[in] innerLen This is the length of the section + * between left and right context. + **/ + Postprocess(uint32_t contextLen, + uint32_t innerLen, + uint32_t blankTokenIdx); + + Postprocess() = delete; + ~Postprocess() = default; + + /** + * @brief Erases the required part of the tensor based + * on context lengths set up during initialisation. + * @param[in] tensor Pointer to the tensor. + * @param[in] axisIdx Index of the axis on which erase is + * performed. + * @param[in] lastIteration Flag to signal this is the + * last iteration in which case + * the right context is preserved. + * @return true if successful, false otherwise. + */ + bool Invoke(TfLiteTensor* tensor, + uint32_t axisIdx, + bool lastIteration = false); + + private: + uint32_t _m_contextLen; /* lengths of left and right contexts. */ + uint32_t _m_innerLen; /* Length of inner context. */ + uint32_t _m_totalLen; /* Total length of the required axis. */ + uint32_t _m_countIterations; /* Current number of iterations. */ + uint32_t _m_blankTokenIdx; /* Index of the labels blank token. */ + /** + * @brief Checks if the tensor and axis index are valid + * inputs to the object - based on how it has been + * initialised. + * @return true if valid, false otherwise. + */ + bool _IsInputValid(TfLiteTensor* tensor, + uint32_t axisIdx) const; + + /** + * @brief Gets the tensor data element size in bytes based + * on the tensor type. + * @return Size in bytes, 0 if not supported. + */ + uint32_t _GetTensorElementSize(TfLiteTensor* tensor); + + /** + * @brief Erases sections from the data assuming row-wise + * arrangement along the context axis. + * @return true if successful, false otherwise. 
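/* A minimal sketch of how the post-processing object above is driven. The
 * 49/50 context/inner lengths are illustrative assumptions; the blank token
 * index of 28 matches the value used in MainLoop.cc below. The erased axis
 * must hold exactly 2 * contextLen + innerLen elements, and the right context
 * is only preserved on the final window. */
arm::app::audio::asr::Postprocess postp(49 /* contextLen */,
                                        50 /* innerLen */,
                                        28 /* blank token index */);
while (audioDataSlider.HasNext()) {   /* 'audioDataSlider' assumed, as in UseCaseHandler.cc */
    /* ... run inference on the current window ... */
    postp.Invoke(outputTensor,
                 arm::app::Wav2LetterModel::ms_outputRowsIdx,
                 !audioDataSlider.HasNext() /* lastIteration */);
}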
+ */ + bool _EraseSectionsRowWise(uint8_t* ptrData, + uint32_t strideSzBytes, + bool lastIteration); + + /** + * @brief Erases sections from the data assuming col-wise + * arrangement along the context axis. + * @return true if successful, false otherwise. + */ + bool _EraseSectionsColWise(uint8_t* ptrData, + uint32_t strideSzBytes, + bool lastIteration); + }; + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_POSTPROCESS_HPP */ \ No newline at end of file diff --git a/source/use_case/asr/include/Wav2LetterPreprocess.hpp b/source/use_case/asr/include/Wav2LetterPreprocess.hpp new file mode 100644 index 0000000..8a4e0b7 --- /dev/null +++ b/source/use_case/asr/include/Wav2LetterPreprocess.hpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_WAV2LETTER_PREPROCESS_HPP +#define ASR_WAV2LETTER_PREPROCESS_HPP + +#include "Wav2LetterModel.hpp" +#include "Wav2LetterMfcc.hpp" +#include "AudioUtils.hpp" +#include "DataStructures.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + /* Class to facilitate pre-processing calculation for Wav2Letter model + * for ASR. */ + using AudioWindow = SlidingWindow ; + + class Preprocess { + public: + /** + * @brief Constructor. + * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] windowLen Number of elements in a window. + * @param[in] windowStride Stride (in number of elements) for + * moving the window. + * @param[in] numMfccVectors Number of MFCC vectors per window. + */ + Preprocess( + uint32_t numMfccFeatures, + uint32_t windowLen, + uint32_t windowStride, + uint32_t numMfccVectors); + Preprocess() = delete; + ~Preprocess() = default; + + /** + * @brief Calculates the features required from audio data. This + * includes MFCC, first and second order deltas, + * normalisation and finally, quantisation. The tensor is + * populated with feature from a given window placed along + * in a single row. + * @param[in] audioData Pointer to the first element of audio data. + * @param[in] audioDataLen Number of elements in the audio data. + * @param[in] tensor Tensor to be populated. + * @return true if successful, false in case of error. + */ + bool Invoke(const int16_t * audioData, + uint32_t audioDataLen, + TfLiteTensor * tensor); + + protected: + /** + * @brief Computes the first and second order deltas for the + * MFCC buffers - they are assumed to be populated. + * + * @param[in] mfcc MFCC buffers. + * @param[out] delta1 Result of the first diff computation. + * @param[out] delta2 Result of the second diff computation. + * @return true if successful, false otherwise. + */ + static bool _ComputeDeltas(Array2d& mfcc, + Array2d& delta1, + Array2d& delta2); + + /** + * @brief Given a 2D vector of floats, computes the mean. + * @param[in] vec Vctor of vector of floats. + * @return Mean value. 
+ */ + static float _GetMean(Array2d& vec); + + /** + * @brief Given a 2D vector of floats, computes the stddev. + * @param[in] vec Vector of vector of floats. + * @param[in] mean Mean value of the vector passed in. + * @return stddev value. + */ + static float _GetStdDev(Array2d& vec, + float mean); + + /** + * @brief Given a 2D vector of floats, normalises it using + * the mean and the stddev. + * @param[in,out] vec Vector of vector of floats. + */ + static void _NormaliseVec(Array2d& vec); + + /** + * @brief Normalises the MFCC and delta buffers. + */ + void _Normalise(); + + /** + * @brief Given the quantisation and data type limits, computes + * the quantised values of a floating point input data. + * @param[in] elem Element to be quantised. + * @param[in] quantScale Scale. + * @param[in] quantOffset Offset. + * @param[in] minVal Numerical limit - minimum. + * @param[in] maxVal Numerical limit - maximum. + * @return Floating point quantised value. + */ + static float _GetQuantElem( + float elem, + float quantScale, + int quantOffset, + float minVal, + float maxVal); + + /** + * @brief Quantises the MFCC and delta buffers, and places them + * in the output buffer. While doing so, it transposes + * the data. Reason: Buffers in this class are arranged + * for "time" axis to be row major. Primary reason for + * this being the convolution speed up (as we can use + * contiguous memory). The output, however, requires the + * time axis to be in column major arrangement. + * @param[in] outputBuf Pointer to the output buffer. + * @param[in] outputBufSz Output buffer's size. + * @param[in] quantScale Quantisation scale. + * @param[in] quantOffset Quantisation offset. + */ + template + bool _Quantise( + T * outputBuf, + const uint32_t outputBufSz, + const float quantScale, + const int quantOffset) + { + /* Check the output size will fit everything. */ + if (outputBufSz < (this->_m_mfccBuf.size(0) * 3 * sizeof(T))) { + printf_err("Tensor size too small for features\n"); + return false; + } + + /* Populate. */ + T * outputBufMfcc = outputBuf; + T * outputBufD1 = outputBuf + this->_m_numMfccFeats; + T * outputBufD2 = outputBufD1 + this->_m_numMfccFeats; + const uint32_t ptrIncr = this->_m_numMfccFeats * 2; /* (3 vectors - 1 vector) */ + + const float minVal = std::numeric_limits::min(); + const float maxVal = std::numeric_limits::max(); + + /* Need to transpose while copying and concatenating the tensor. */ + for (uint32_t j = 0; j < this->_m_numFeatVectors; ++j) { + for (uint32_t i = 0; i < this->_m_numMfccFeats; ++i) { + *outputBufMfcc++ = static_cast(Preprocess::_GetQuantElem( + this->_m_mfccBuf(i, j), quantScale, + quantOffset, minVal, maxVal)); + *outputBufD1++ = static_cast(Preprocess::_GetQuantElem( + this->_m_delta1Buf(i, j), quantScale, + quantOffset, minVal, maxVal)); + *outputBufD2++ = static_cast(Preprocess::_GetQuantElem( + this->_m_delta2Buf(i, j), quantScale, + quantOffset, minVal, maxVal)); + } + outputBufMfcc += ptrIncr; + outputBufD1 += ptrIncr; + outputBufD2 += ptrIncr; + } + + return true; + } + + private: + Wav2LetterMFCC _m_mfcc; /* MFCC instance. */ + + /* Actual buffers to be populated. */ + Array2d _m_mfccBuf; /* Contiguous buffer 1D: MFCC */ + Array2d _m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ + Array2d _m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ + + uint32_t _m_windowLen; /* Window length for MFCC. */ + uint32_t _m_windowStride; /* Window stride len for MFCC. */ + uint32_t _m_numMfccFeats; /* Number of MFCC features per window. 
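/* A worked sketch of the element quantisation performed by _Quantise() and
 * _GetQuantElem() above (the scale and offset values are hypothetical):
 * q = round(v / quantScale) + quantOffset, clamped to the target type's range.
 * Each row of the populated tensor then holds [ MFCCs | delta1 | delta2 ] for
 * one time step. */
const float quantScale  = 0.15f;
const int   quantOffset = -128;
const float v           = 1.2f;   /* one normalised feature value */
const float q = std::min(std::max(std::round(v / quantScale) + quantOffset, -128.0f), 127.0f);
/* q == -120.0f, stored as int8_t(-120) */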
*/ + uint32_t _m_numFeatVectors; /* Number of _m_numMfccFeats. */ + AudioWindow _m_window; /* Sliding window. */ + + }; + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_WAV2LETTER_PREPROCESS_HPP */ \ No newline at end of file diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc new file mode 100644 index 0000000..7377d30 --- /dev/null +++ b/source/use_case/asr/src/AsrClassifier.cc @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AsrClassifier.hpp" + +#include "hal.h" +#include "TensorFlowLiteMicro.hpp" +#include "Wav2LetterModel.hpp" + +template +bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint) +{ + const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx]; + const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]; + + /* NOTE: tensor's size verification against labels should be + * checked by the calling/public function. */ + if (nLetters < 1) { + return false; + } + + /* Final results' container. */ + vecResults = std::vector(nElems); + + T* tensorData = tflite::GetTensorData(tensor); + + /* Get the top 1 results. */ + for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { + std::pair top_1 = std::make_pair(tensorData[row + 0], 0); + + for (uint32_t j = 1; j < nLetters; ++j) { + if (top_1.first < tensorData[row + j]) { + top_1.first = tensorData[row + j]; + top_1.second = j; + } + } + + double score = static_cast (top_1.first); + vecResults[i].m_normalisedVal = scale * (score - zeroPoint); + vecResults[i].m_label = labels[top_1.second]; + vecResults[i].m_labelIdx = top_1.second; + } + + return true; +} +template bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint); +template bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint); + +bool arm::app::AsrClassifier::GetClassificationResults( + TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector & labels, uint32_t topNCount) +{ + vecResults.clear(); + + constexpr int minTensorDims = static_cast( + (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)? + arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx); + + constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx; + + /* Sanity checks. 
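/* A small worked example of the de-quantisation applied to the top-1 score in
 * _GetTopResults() above (numbers are hypothetical): with scale = 1/256 and
 * zeroPoint = -128, a raw int8 score of 77 becomes
 * (1.0 / 256) * (77 - (-128)) = 0.80078125, i.e. roughly 0.8. */
const double scale = 1.0 / 256, zeroPoint = -128.0;
const double normalisedVal = scale * (77.0 - zeroPoint);  /* ~0.80 */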
*/ + if (outputTensor == nullptr) { + printf_err("Output vector is null pointer.\n"); + return false; + } else if (outputTensor->dims->size < minTensorDims) { + printf_err("Output tensor expected to be %dD\n", minTensorDims); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) < topNCount) { + printf_err("Output vectors are smaller than %u\n", topNCount); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } + + if (topNCount != 1) { + warn("TopNCount value ignored in this implementation\n"); + } + + /* To return the floating point values, we need quantization parameters. */ + QuantParams quantParams = GetTensorQuantParams(outputTensor); + + bool resultState; + + switch (outputTensor->type) { + case kTfLiteUInt8: + resultState = this->_GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + case kTfLiteInt8: + resultState = this->_GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + default: + printf_err("Tensor type %s not supported by classifier\n", + TfLiteTypeGetName(outputTensor->type)); + return false; + } + + if (!resultState) { + printf_err("Failed to get sorted set\n"); + return false; + } + + return true; +} \ No newline at end of file diff --git a/source/use_case/asr/src/MainLoop.cc b/source/use_case/asr/src/MainLoop.cc new file mode 100644 index 0000000..ca777be --- /dev/null +++ b/source/use_case/asr/src/MainLoop.cc @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* Brings in platform definitions. */ +#include "Labels.hpp" /* For label strings. */ +#include "UseCaseHandler.hpp" /* Handlers for different user options. */ +#include "Wav2LetterModel.hpp" /* Model class for running inference. */ +#include "UseCaseCommonUtils.hpp" /* Utils functions. */ +#include "AsrClassifier.hpp" /* Classifier. */ +#include "InputFiles.hpp" /* Generated audio clip header. */ +#include "Wav2LetterPreprocess.hpp" /* Pre-processing class. */ +#include "Wav2LetterPostprocess.hpp" /* Post-processing class. */ + +enum opcodes +{ + MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ + MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index. */ + MENU_OPT_RUN_INF_ALL, /* Run inference on all. */ + MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ + MENU_OPT_LIST_AUDIO_CLIPS /* List the current baked audio clips. */ +}; + +static void DisplayMenu() +{ + printf("\n\nUser input required\n"); + printf("Enter option number from:\n\n"); + printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT); + printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); + printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL); + printf(" %u. 
Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); + printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS); + printf(" Choice: "); +} + +/** @brief Verify input and output tensor are of certain min dimensions. */ +static bool VerifyTensorDimensions(const arm::app::Model& model); + +/** @brief Gets the number of MFCC features for a single window. */ +static uint32_t GetNumMfccFeatures(const arm::app::Model& model); + +/** @brief Gets the number of MFCC feature vectors to be computed. */ +static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); + +/** @brief Gets the output context length (left and right) for post-processing. */ +static uint32_t GetOutputContextLen(const arm::app::Model& model, + uint32_t inputCtxLen); + +/** @brief Gets the output inner length for post-processing. */ +static uint32_t GetOutputInnerLen(const arm::app::Model& model, + uint32_t outputCtxLen); + +void main_loop(hal_platform& platform) +{ + arm::app::Wav2LetterModel model; /* Model wrapper object. */ + + /* Load the model. */ + if (!model.Init()) { + printf_err("Failed to initialise model\n"); + return; + } else if (!VerifyTensorDimensions(model)) { + printf_err("Model's input or output dimension verification failed\n"); + return; + } + + /* Initialise pre-processing. */ + arm::app::audio::asr::Preprocess prep( + GetNumMfccFeatures(model), + g_FrameLength, + g_FrameStride, + GetNumMfccFeatureVectors(model)); + + /* Initialise post-processing. */ + const uint32_t outputCtxLen = GetOutputContextLen(model, g_ctxLen); + const uint32_t blankTokenIdx = 28; + arm::app::audio::asr::Postprocess postp( + outputCtxLen, + GetOutputInnerLen(model, outputCtxLen), + blankTokenIdx); + + /* Instantiate application context. */ + arm::app::ApplicationContext caseContext; + std::vector labels; + GetLabelsVector(labels); + arm::app::AsrClassifier classifier; /* Classifier wrapper object. */ + + caseContext.Set("platform", platform); + caseContext.Set("model", model); + caseContext.Set("clipIndex", 0); + caseContext.Set("frameLength", g_FrameLength); + caseContext.Set("frameStride", g_FrameStride); + caseContext.Set("scoreThreshold", g_ScoreThreshold); /* Score threshold. */ + caseContext.Set("ctxLen", g_ctxLen); /* Left and right context length (MFCC feat vectors). */ + caseContext.Set&>("labels", labels); + caseContext.Set("classifier", classifier); + caseContext.Set("preprocess", prep); + caseContext.Set("postprocess", postp); + + bool executionSuccessful = true; + constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; + + /* Loop. 
*/ + do { + int menuOption = MENU_OPT_RUN_INF_NEXT; + if (bUseMenu) { + DisplayMenu(); + menuOption = arm::app::ReadUserInputAsInt(platform); + printf("\n"); + } + switch (menuOption) { + case MENU_OPT_RUN_INF_NEXT: + executionSuccessful = ClassifyAudioHandler( + caseContext, + caseContext.Get("clipIndex"), + false); + break; + case MENU_OPT_RUN_INF_CHOSEN: { + printf(" Enter the audio clip index [0, %d]: ", + NUMBER_OF_FILES-1); + auto clipIndex = static_cast( + arm::app::ReadUserInputAsInt(platform)); + executionSuccessful = ClassifyAudioHandler(caseContext, + clipIndex, + false); + break; + } + case MENU_OPT_RUN_INF_ALL: + executionSuccessful = ClassifyAudioHandler( + caseContext, + caseContext.Get("clipIndex"), + true); + break; + case MENU_OPT_SHOW_MODEL_INFO: + executionSuccessful = model.ShowModelInfoHandler(); + break; + case MENU_OPT_LIST_AUDIO_CLIPS: + executionSuccessful = ListFilesHandler(caseContext); + break; + default: + printf("Incorrect choice, try again."); + break; + } + } while (executionSuccessful && bUseMenu); + info("Main loop terminated.\n"); +} + +static bool VerifyTensorDimensions(const arm::app::Model& model) +{ + /* Populate tensor related parameters. */ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + if (!outputTensor->dims) { + printf_err("Invalid output tensor dims\n"); + return false; + } else if (outputTensor->dims->size < 3) { + printf_err("Output tensor dimension should be >= 3\n"); + return false; + } + + return true; +} + +static uint32_t GetNumMfccFeatures(const arm::app::Model& model) +{ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; + if (0 != inputCols % 3) { + printf_err("Number of input columns is not a multiple of 3\n"); + } + return std::max(inputCols/3, 0); +} + +static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) +{ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; + return std::max(inputRows, 0); +} + +static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) +{ + const uint32_t inputRows = GetNumMfccFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above + * context and inner lengths. 
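/* A worked example for the two helpers above, assuming a wav2letter input
 * tensor of shape [1, 296, 39] (the shape values are an assumption, not taken
 * from this patch): each column triplet holds [MFCC, delta1, delta2], so
 * GetNumMfccFeatures() returns 39 / 3 = 13 and GetNumMfccFeatureVectors()
 * returns 296. */
constexpr int inputCols = 39, inputRows = 296;     /* assumed input shape [1, 296, 39] */
constexpr int numMfccFeatures    = inputCols / 3;  /* 13  */
constexpr int numMfccFeatVectors = inputRows;      /* 296 */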
*/ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %u\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + + const float tensorColRatio = static_cast(inputRows)/ + static_cast(outputRows); + + return std::round(static_cast(inputCtxLen)/tensorColRatio); +} + +static uint32_t GetOutputInnerLen(const arm::app::Model& model, + const uint32_t outputCtxLen) +{ + constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + return (outputRows - (2 * outputCtxLen)); +} diff --git a/source/use_case/asr/src/OutputDecode.cc b/source/use_case/asr/src/OutputDecode.cc new file mode 100644 index 0000000..41fbe07 --- /dev/null +++ b/source/use_case/asr/src/OutputDecode.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "OutputDecode.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + std::string DecodeOutput(const std::vector& vecResults) + { + std::string CleanOutputBuffer; + + for (size_t i = 0; i < vecResults.size(); ++i) /* For all elements in vector. */ + { + while (i+1 < vecResults.size() && + vecResults[i].m_label == vecResults[i+1].m_label) /* While the current element is equal to the next, ignore it and move on. */ + { + ++i; + } + if (vecResults[i].m_label != "$") /* $ is a character used to represent unknown and double characters so should not be in output. */ + { + CleanOutputBuffer += vecResults[i].m_label; /* If the element is different to the next, it will be appended to CleanOutputBuffer. */ + } + } + + return CleanOutputBuffer; /* Return string type containing clean output. */ + } + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/asr/src/UseCaseHandler.cc b/source/use_case/asr/src/UseCaseHandler.cc new file mode 100644 index 0000000..e706eb8 --- /dev/null +++ b/source/use_case/asr/src/UseCaseHandler.cc @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
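/* A worked example for GetOutputContextLen()/GetOutputInnerLen() above, using
 * g_ctxLen = 98 and assuming 296 input feature vectors and 148 output rows
 * (the 296 and 148 values are assumptions): the row ratio is 296 / 148 = 2,
 * so the output context length is round(98 / 2) = 49 and the output inner
 * length is 148 - 2 * 49 = 50. */
constexpr uint32_t inputRows = 296, outputRows = 148, inputCtxLen = 98;  /* 296/148 assumed */
constexpr uint32_t outputCtxLen   = 49;                                  /* round(98 / 2)   */
constexpr uint32_t outputInnerLen = outputRows - 2 * outputCtxLen;       /* 50              */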
+ */ +#include "UseCaseHandler.hpp" + +#include "InputFiles.hpp" +#include "AsrClassifier.hpp" +#include "Wav2LetterModel.hpp" +#include "hal.h" +#include "Wav2LetterMfcc.hpp" +#include "AudioUtils.hpp" +#include "UseCaseCommonUtils.hpp" +#include "AsrResult.hpp" +#include "Wav2LetterPreprocess.hpp" +#include "Wav2LetterPostprocess.hpp" +#include "OutputDecode.hpp" + +namespace arm { +namespace app { + + /** + * @brief Helper function to increment current audio clip index. + * @param[in,out] ctx Pointer to the application context object. + **/ + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx); + + /** + * @brief Helper function to set the audio clip index. + * @param[in,out] ctx Pointer to the application context object. + * @param[in] idx Value to be set. + * @return true if index is set, false otherwise. + **/ + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); + + /** + * @brief Presents inference results using the data presentation + * object. + * @param[in] platform Reference to the hal platform object. + * @param[in] results Vector of classification results to be displayed. + * @param[in] infTimeMs Inference time in milliseconds, if available + * otherwise, this can be passed in as 0. + * @return true if successful, false otherwise. + **/ + static bool _PresentInferenceResult( + hal_platform& platform, + const std::vector& results); + + /* Audio inference classification handler. */ + bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) + { + constexpr uint32_t dataPsnTxtInfStartX = 20; + constexpr uint32_t dataPsnTxtInfStartY = 40; + + auto& platform = ctx.Get("platform"); + platform.data_psn->clear(COLOR_BLACK); + + /* If the request has a valid size, set the audio index. */ + if (clipIndex < NUMBER_OF_FILES) { + if (!_SetAppCtxClipIdx(ctx, clipIndex)) { + return false; + } + } + + /* Get model reference. */ + auto& model = ctx.Get("model"); + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + /* Get score threshold to be applied for the classifier (post-inference). */ + auto scoreThreshold = ctx.Get("scoreThreshold"); + + /* Get tensors. Dimensions of the tensor should have been verified by + * the callee. */ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; + + /* Populate MFCC related parameters. */ + auto mfccParamsWinLen = ctx.Get("frameLength"); + auto mfccParamsWinStride = ctx.Get("frameStride"); + + /* Populate ASR inference context and inner lengths for input. */ + auto inputCtxLen = ctx.Get("ctxLen"); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + + /* Audio data stride corresponds to inputInnerLen feature vectors. */ + const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen); + const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride; + const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq); + + /* Get pre/post-processing objects. */ + auto& prep = ctx.Get("preprocess"); + auto& postp = ctx.Get("postprocess"); + + /* Set default reduction axis for post-processing. */ + const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx; + + /* Audio clip start index. */ + auto startClipIdx = ctx.Get("clipIndex"); + + /* Loop to process audio clips. 
*/ + do { + /* Get current audio clip index. */ + auto currentIndex = ctx.Get("clipIndex"); + + /* Get the current audio buffer and respective size. */ + const int16_t* audioArr = get_audio_array(currentIndex); + const uint32_t audioArrSize = get_audio_array_size(currentIndex); + + if (!audioArr) { + printf_err("Invalid audio array pointer\n"); + return false; + } + + /* Audio clip must have enough samples to produce 1 MFCC feature. */ + if (audioArrSize < mfccParamsWinLen) { + printf_err("Not enough audio samples, minimum needed is %u\n", mfccParamsWinLen); + return false; + } + + /* Initialise an audio slider. */ + auto audioDataSlider = audio::ASRSlidingWindow( + audioArr, + audioArrSize, + audioParamsWinLen, + audioParamsWinStride); + + /* Declare a container for results. */ + std::vector results; + + /* Display message on the LCD - inference running. */ + std::string str_inf{"Running inference... "}; + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + info("Running inference on audio clip %u => %s\n", currentIndex, + get_filename(currentIndex)); + + size_t inferenceWindowLen = audioParamsWinLen; + + /* Start sliding through audio clip. */ + while (audioDataSlider.HasNext()) { + + /* If not enough audio see how much can be sent for processing. */ + size_t nextStartIndex = audioDataSlider.NextWindowStartIndex(); + if (nextStartIndex + audioParamsWinLen > audioArrSize) { + inferenceWindowLen = audioArrSize - nextStartIndex; + } + + const int16_t* inferenceWindow = audioDataSlider.Next(); + + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + static_cast(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); + + Profiler prepProfiler{&platform, "pre-processing"}; + prepProfiler.StartProfiling(); + + /* Calculate MFCCs, deltas and populate the input tensor. */ + prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor); + + prepProfiler.StopProfiling(); + std::string prepProfileResults = prepProfiler.GetResultsAndReset(); + info("%s\n", prepProfileResults.c_str()); + + /* Run inference over this audio clip sliding window. */ + arm::app::RunInference(platform, model); + + /* Post-process. */ + postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext()); + + /* Get results. */ + std::vector classificationResult; + auto& classifier = ctx.Get("classifier"); + classifier.GetClassificationResults( + outputTensor, classificationResult, + ctx.Get&>("labels"), 1); + + results.emplace_back(asr::AsrResult(classificationResult, + (audioDataSlider.Index() * + audioParamsSecondsPerSample * + audioParamsWinStride), + audioDataSlider.Index(), scoreThreshold)); + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(outputTensor, + outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]); +#endif /* VERIFY_TEST_OUTPUT */ + + } + + /* Erase. 
*/ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + ctx.Set>("results", results); + + if (!_PresentInferenceResult(platform, results)) { + return false; + } + + _IncrementAppCtxClipIdx(ctx); + + } while (runAll && ctx.Get("clipIndex") != startClipIdx); + + return true; + } + + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx) + { + auto curAudioIdx = ctx.Get("clipIndex"); + + if (curAudioIdx + 1 >= NUMBER_OF_FILES) { + ctx.Set("clipIndex", 0); + return; + } + ++curAudioIdx; + ctx.Set("clipIndex", curAudioIdx); + } + + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx) + { + if (idx >= NUMBER_OF_FILES) { + printf_err("Invalid idx %u (expected less than %u)\n", + idx, NUMBER_OF_FILES); + return false; + } + + ctx.Set("clipIndex", idx); + return true; + } + + static bool _PresentInferenceResult(hal_platform& platform, + const std::vector& results) + { + constexpr uint32_t dataPsnTxtStartX1 = 20; + constexpr uint32_t dataPsnTxtStartY1 = 60; + constexpr bool allow_multiple_lines = true; + + platform.data_psn->set_text_color(COLOR_GREEN); + + /* Results from multiple inferences should be combined before processing. */ + std::vector combinedResults; + for (auto& result : results) { + combinedResults.insert(combinedResults.end(), + result.m_resultVec.begin(), + result.m_resultVec.end()); + } + + /* Get each inference result string using the decoder. */ + for (const auto & result : results) { + std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); + + info("Result for inf %u: %s\n", result.m_inferenceNumber, + infResultStr.c_str()); + } + + /* Get the decoded result for the combined result. */ + std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); + + platform.data_psn->present_data_text( + finalResultStr.c_str(), finalResultStr.size(), + dataPsnTxtStartX1, dataPsnTxtStartY1, + allow_multiple_lines); + + info("Final result: %s\n", finalResultStr.c_str()); + return true; + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/asr/src/Wav2LetterMfcc.cc b/source/use_case/asr/src/Wav2LetterMfcc.cc new file mode 100644 index 0000000..92c91bc --- /dev/null +++ b/source/use_case/asr/src/Wav2LetterMfcc.cc @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
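/* A self-contained restatement of the decoding used by _PresentInferenceResult()
 * above: DecodeOutput() collapses runs of identical labels and then drops the
 * '$' blank token, so a per-frame label sequence of h h e e $ l l $ l o o
 * decodes to "hello". */
const std::vector<std::string> frames = {"h","h","e","e","$","l","l","$","l","o","o"};
std::string out;
for (size_t i = 0; i < frames.size(); ++i) {
    while (i + 1 < frames.size() && frames[i] == frames[i + 1]) { ++i; }  /* skip repeats */
    if (frames[i] != "$") { out += frames[i]; }                           /* drop blanks  */
}
/* out == "hello" */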
+ */ +#include "Wav2LetterMfcc.hpp" + +#include "PlatformMath.hpp" + +#include + +namespace arm { +namespace app { +namespace audio { + + bool Wav2LetterMFCC::ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) + { + const size_t numBanks = melEnergies.size(); + + if (numBanks != filterBankFilterFirst.size() || + numBanks != filterBankFilterLast.size()) { + printf_err("Unexpected filter bank lengths\n"); + return false; + } + + for (size_t bin = 0; bin < numBanks; ++bin) { + auto filterBankIter = melFilterBank[bin].begin(); + float melEnergy = 1e-10; /* Avoid log of zero at later stages, same value used in librosa. */ + const int32_t firstIndex = filterBankFilterFirst[bin]; + const int32_t lastIndex = filterBankFilterLast[bin]; + + for (int32_t i = firstIndex; i <= lastIndex; ++i) { + melEnergy += (*filterBankIter++ * fftVec[i]); + } + + melEnergies[bin] = melEnergy; + } + + return true; + } + + void Wav2LetterMFCC::ConvertToLogarithmicScale( + std::vector& melEnergies) + { + float maxMelEnergy = -FLT_MAX; + + /* Container for natural logarithms of mel energies. */ + std::vector vecLogEnergies(melEnergies.size(), 0.f); + + /* Because we are taking natural logs, we need to multiply by log10(e). + * Also, for wav2letter model, we scale our log10 values by 10. */ + constexpr float multiplier = 10.0 * /* Default scalar. */ + 0.4342944819032518; /* log10f(std::exp(1.0)) */ + + /* Take log of the whole vector. */ + math::MathUtils::VecLogarithmF32(melEnergies, vecLogEnergies); + + /* Scale the log values and get the max. */ + for (auto iterM = melEnergies.begin(), iterL = vecLogEnergies.begin(); + iterM != melEnergies.end(); ++iterM, ++iterL) { + + *iterM = *iterL * multiplier; + + /* Save the max mel energy. */ + if (*iterM > maxMelEnergy) { + maxMelEnergy = *iterM; + } + } + + /* Clamp the mel energies. */ + constexpr float maxDb = 80.0; + const float clampLevelLowdB = maxMelEnergy - maxDb; + for (auto iter = melEnergies.begin(); iter != melEnergies.end(); ++iter) { + *iter = std::max(*iter, clampLevelLowdB); + } + } + + std::vector Wav2LetterMFCC::CreateDCTMatrix( + const int32_t inputLength, + const int32_t coefficientCount) + { + std::vector dctMatix(inputLength * coefficientCount); + + /* Orthonormal normalization. */ + const float normalizerK0 = 2 * math::MathUtils::SqrtF32(1.0f / + static_cast(4*inputLength)); + const float normalizer = 2 * math::MathUtils::SqrtF32(1.0f / + static_cast(2*inputLength)); + + const float angleIncr = M_PI / inputLength; + float angle = angleIncr; /* We start using it at k = 1 loop. */ + + /* First row of DCT will use normalizer K0. */ + for (int32_t n = 0; n < inputLength; ++n) { + dctMatix[n] = normalizerK0 /* cos(0) = 1 */; + } + + /* Second row (index = 1) onwards, we use standard normalizer. */ + for (int32_t k = 1, m = inputLength; k < coefficientCount; ++k, m += inputLength) { + for (int32_t n = 0; n < inputLength; ++n) { + dctMatix[m+n] = normalizer * + math::MathUtils::CosineF32((n + 0.5f) * angle); + } + angle += angleIncr; + } + return dctMatix; + } + + float Wav2LetterMFCC::GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod) + { + /* Slaney normalization for mel weights. 
*/ + return (2.0f / (MFCC::InverseMelScale(rightMel, useHTKMethod) - + MFCC::InverseMelScale(leftMel, useHTKMethod))); + } + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/asr/src/Wav2LetterModel.cc b/source/use_case/asr/src/Wav2LetterModel.cc new file mode 100644 index 0000000..5aefecd --- /dev/null +++ b/source/use_case/asr/src/Wav2LetterModel.cc @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterModel.hpp" + +#include "hal.h" + +const tflite::MicroOpResolver& arm::app::Wav2LetterModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +bool arm::app::Wav2LetterModel::EnlistOperations() +{ + this->_m_opResolver.AddConv2D(); + this->_m_opResolver.AddMul(); + this->_m_opResolver.AddMaximum(); + this->_m_opResolver.AddReshape(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->_m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + + return true; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::Wav2LetterModel::ModelPointer() +{ + return GetModelPointer(); +} + +extern size_t GetModelLen(); +size_t arm::app::Wav2LetterModel::ModelSize() +{ + return GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/asr/src/Wav2LetterPostprocess.cc b/source/use_case/asr/src/Wav2LetterPostprocess.cc new file mode 100644 index 0000000..60ee51e --- /dev/null +++ b/source/use_case/asr/src/Wav2LetterPostprocess.cc @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterPostprocess.hpp" + +#include "Wav2LetterModel.hpp" + + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + Postprocess::Postprocess(const uint32_t contextLen, + const uint32_t innerLen, + const uint32_t blankTokenIdx) + : _m_contextLen(contextLen), + _m_innerLen(innerLen), + _m_totalLen(2 * this->_m_contextLen + this->_m_innerLen), + _m_countIterations(0), + _m_blankTokenIdx(blankTokenIdx) + {} + + bool Postprocess::Invoke(TfLiteTensor* tensor, + const uint32_t axisIdx, + const bool lastIteration) + { + /* Basic checks. 
*/ + if (!this->_IsInputValid(tensor, axisIdx)) { + return false; + } + + /* Irrespective of tensor type, we use unsigned "byte" */ + uint8_t* ptrData = tflite::GetTensorData(tensor); + const uint32_t elemSz = this->_GetTensorElementSize(tensor); + + /* Other sanity checks. */ + if (0 == elemSz) { + printf_err("Tensor type not supported for post processing\n"); + return false; + } else if (elemSz * this->_m_totalLen > tensor->bytes) { + printf_err("Insufficient number of tensor bytes\n"); + return false; + } + + /* Which axis do we need to process? */ + switch (axisIdx) { + case arm::app::Wav2LetterModel::ms_outputRowsIdx: + return this->_EraseSectionsRowWise(ptrData, + elemSz * tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx], + lastIteration); + case arm::app::Wav2LetterModel::ms_outputColsIdx: + return this->_EraseSectionsColWise(ptrData, + elemSz * tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx], + lastIteration); + default: + printf_err("Unsupported axis index: %u\n", axisIdx); + } + + return false; + } + + bool Postprocess::_IsInputValid(TfLiteTensor* tensor, + const uint32_t axisIdx) const + { + if (nullptr == tensor) { + return false; + } + + if (static_cast(axisIdx) >= tensor->dims->size) { + printf_err("Invalid axis index: %u; Max: %d\n", + axisIdx, tensor->dims->size); + return false; + } + + if (static_cast(this->_m_totalLen) != + tensor->dims->data[axisIdx]) { + printf_err("Unexpected tensor dimension for axis %d, \n", + tensor->dims->data[axisIdx]); + return false; + } + + return true; + } + + uint32_t Postprocess::_GetTensorElementSize(TfLiteTensor* tensor) + { + switch(tensor->type) { + case kTfLiteUInt8: + return 1; + case kTfLiteInt8: + return 1; + case kTfLiteInt16: + return 2; + case kTfLiteInt32: + return 4; + case kTfLiteFloat32: + return 4; + default: + printf_err("Unsupported tensor type %s\n", + TfLiteTypeGetName(tensor->type)); + } + + return 0; + } + + bool Postprocess::_EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) + { + /* In this case, the "zero-ing" is quite simple as the region + * to be zeroed sits in contiguous memory (row-major). */ + const uint32_t eraseLen = strideSzBytes * this->_m_contextLen; + + /* Erase left context? */ + if (this->_m_countIterations > 0) { + /* Set output of each classification window to the blank token. */ + std::memset(ptrData, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) { + ptrData[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1; + } + } + + /* Erase right context? */ + if (false == lastIteration) { + uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->_m_contextLen + this->_m_innerLen)); + /* Set output of each classification window to the blank token. */ + std::memset(rightCtxPtr, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) { + rightCtxPtr[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1; + } + } + + if (lastIteration) { + this->_m_countIterations = 0; + } else { + ++this->_m_countIterations; + } + + return true; + } + + bool Postprocess::_EraseSectionsColWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) + { + /* Not implemented. 
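/* A worked sketch of the row-wise erase above, with assumed sizes: for an int8
 * output of 148 rows (49 left-context + 50 inner + 49 right-context) and 29
 * classes per row, the stride is 29 bytes and each erased context spans
 * 49 * 29 = 1421 bytes. Erased rows are zeroed and their blank-token entry
 * (index 28) is set to 1 so the decoder treats them as blanks. */
constexpr uint32_t numClasses = 29, ctxRows = 49;               /* assumed */
constexpr uint32_t strideSzBytes = numClasses * sizeof(int8_t); /* 29   */
constexpr uint32_t eraseLenBytes = ctxRows * strideSzBytes;     /* 1421 */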
*/ + UNUSED(ptrData); + UNUSED(strideSzBytes); + UNUSED(lastIteration); + return false; + } + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/asr/src/Wav2LetterPreprocess.cc b/source/use_case/asr/src/Wav2LetterPreprocess.cc new file mode 100644 index 0000000..e46cca3 --- /dev/null +++ b/source/use_case/asr/src/Wav2LetterPreprocess.cc @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterPreprocess.hpp" + +#include "PlatformMath.hpp" +#include "TensorFlowLiteMicro.hpp" + +#include +#include + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + Preprocess::Preprocess( + const uint32_t numMfccFeatures, + const uint32_t windowLen, + const uint32_t windowStride, + const uint32_t numMfccVectors): + _m_mfcc(numMfccFeatures, windowLen), + _m_mfccBuf(numMfccFeatures, numMfccVectors), + _m_delta1Buf(numMfccFeatures, numMfccVectors), + _m_delta2Buf(numMfccFeatures, numMfccVectors), + _m_windowLen(windowLen), + _m_windowStride(windowStride), + _m_numMfccFeats(numMfccFeatures), + _m_numFeatVectors(numMfccVectors), + _m_window() + { + if (numMfccFeatures > 0 && windowLen > 0) { + this->_m_mfcc.Init(); + } + } + + bool Preprocess::Invoke( + const int16_t* audioData, + const uint32_t audioDataLen, + TfLiteTensor* tensor) + { + this->_m_window = SlidingWindow( + audioData, audioDataLen, + this->_m_windowLen, this->_m_windowStride); + + uint32_t mfccBufIdx = 0; + + std::fill(_m_mfccBuf.begin(), _m_mfccBuf.end(), 0.f); + std::fill(_m_delta1Buf.begin(), _m_delta1Buf.end(), 0.f); + std::fill(_m_delta2Buf.begin(), _m_delta2Buf.end(), 0.f); + + /* While we can slide over the window. */ + while (this->_m_window.HasNext()) { + const int16_t* mfccWindow = this->_m_window.Next(); + auto mfccAudioData = std::vector( + mfccWindow, + mfccWindow + this->_m_windowLen); + auto mfcc = this->_m_mfcc.MfccCompute(mfccAudioData); + for (size_t i = 0; i < this->_m_mfccBuf.size(0); ++i) { + this->_m_mfccBuf(i, mfccBufIdx) = mfcc[i]; + } + ++mfccBufIdx; + } + + /* Pad MFCC if needed by adding MFCC for zeros. */ + if (mfccBufIdx != this->_m_numFeatVectors) { + std::vector zerosWindow = std::vector(this->_m_windowLen, 0); + std::vector mfccZeros = this->_m_mfcc.MfccCompute(zerosWindow); + + while (mfccBufIdx != this->_m_numFeatVectors) { + memcpy(&this->_m_mfccBuf(0, mfccBufIdx), + mfccZeros.data(), sizeof(float) * _m_numMfccFeats); + ++mfccBufIdx; + } + } + + /* Compute first and second order deltas from MFCCs. */ + this->_ComputeDeltas(this->_m_mfccBuf, + this->_m_delta1Buf, + this->_m_delta2Buf); + + /* Normalise. */ + this->_Normalise(); + + /* Quantise. 
*/ + QuantParams quantParams = GetTensorQuantParams(tensor); + + if (0 == quantParams.scale) { + printf_err("Quantisation scale can't be 0\n"); + return false; + } + + switch(tensor->type) { + case kTfLiteUInt8: + return this->_Quantise( + tflite::GetTensorData(tensor), tensor->bytes, + quantParams.scale, quantParams.offset); + case kTfLiteInt8: + return this->_Quantise( + tflite::GetTensorData(tensor), tensor->bytes, + quantParams.scale, quantParams.offset); + default: + printf_err("Unsupported tensor type %s\n", + TfLiteTypeGetName(tensor->type)); + } + + return false; + } + + bool Preprocess::_ComputeDeltas(Array2d& mfcc, + Array2d& delta1, + Array2d& delta2) + { + const std::vector delta1Coeffs = + {6.66666667e-02, 5.00000000e-02, 3.33333333e-02, + 1.66666667e-02, -3.46944695e-18, -1.66666667e-02, + -3.33333333e-02, -5.00000000e-02, -6.66666667e-02}; + + const std::vector delta2Coeffs = + {0.06060606, 0.01515152, -0.01731602, + -0.03679654, -0.04329004, -0.03679654, + -0.01731602, 0.01515152, 0.06060606}; + + if (delta1.size(0) == 0 || delta2.size(0) != delta1.size(0) || + mfcc.size(0) == 0 || mfcc.size(1) == 0) { + return false; + } + + /* Get the middle index; coeff vec len should always be odd. */ + const size_t coeffLen = delta1Coeffs.size(); + const size_t fMidIdx = (coeffLen - 1)/2; + const size_t numFeatures = mfcc.size(0); + const size_t numFeatVectors = mfcc.size(1); + + /* Iterate through features in MFCC vector. */ + for (size_t i = 0; i < numFeatures; ++i) { + /* For each feature, iterate through time (t) samples representing feature evolution and + * calculate d/dt and d^2/dt^2, using 1D convolution with differential kernels. + * Convolution padding = valid, result size is `time length - kernel length + 1`. + * The result is padded with 0 from both sides to match the size of initial time samples data. + * + * For the small filter, conv1D implementation as a simple loop is efficient enough. + * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32. 
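/* Expressed as a formula, the loop above computes, for each feature row i and
 * each time step j in the "valid" region (4 <= j < T - 4, with T the number of
 * feature vectors):
 *
 *     delta1(i, j) = sum_{k=0..8} mfcc(i, j - 4 + k) * c1[8 - k]
 *     delta2(i, j) = sum_{k=0..8} mfcc(i, j - 4 + k) * c2[8 - k]
 *
 * where c1/c2 are the 9-tap kernels above; the first and last four time steps
 * keep their zero-initialised values. */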
+ */ + + for (size_t j = fMidIdx; j < numFeatVectors - fMidIdx; ++j) { + float d1 = 0; + float d2 = 0; + const size_t mfccStIdx = j - fMidIdx; + + for (size_t k = 0, m = coeffLen - 1; k < coeffLen; ++k, --m) { + + d1 += mfcc(i,mfccStIdx + k) * delta1Coeffs[m]; + d2 += mfcc(i,mfccStIdx + k) * delta2Coeffs[m]; + } + + delta1(i,j) = d1; + delta2(i,j) = d2; + } + } + + return true; + } + + float Preprocess::_GetMean(Array2d& vec) + { + return math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + } + + float Preprocess::_GetStdDev(Array2d& vec, const float mean) + { + return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); + } + + void Preprocess::_NormaliseVec(Array2d& vec) + { + auto mean = Preprocess::_GetMean(vec); + auto stddev = Preprocess::_GetStdDev(vec, mean); + + debug("Mean: %f, Stddev: %f\n", mean, stddev); + if (stddev == 0) { + std::fill(vec.begin(), vec.end(), 0); + } else { + const float stddevInv = 1.f/stddev; + const float normalisedMean = mean/stddev; + + auto NormalisingFunction = [=](float& value) { + value = value * stddevInv - normalisedMean; + }; + std::for_each(vec.begin(), vec.end(), NormalisingFunction); + } + } + + void Preprocess::_Normalise() + { + Preprocess::_NormaliseVec(this->_m_mfccBuf); + Preprocess::_NormaliseVec(this->_m_delta1Buf); + Preprocess::_NormaliseVec(this->_m_delta2Buf); + } + + float Preprocess::_GetQuantElem( + const float elem, + const float quantScale, + const int quantOffset, + const float minVal, + const float maxVal) + { + float val = std::round((elem/quantScale) + quantOffset); + return std::min(std::max(val, minVal), maxVal); + } + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/asr/usecase.cmake b/source/use_case/asr/usecase.cmake new file mode 100644 index 0000000..e4b8752 --- /dev/null +++ b/source/use_case/asr/usecase.cmake @@ -0,0 +1,164 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#---------------------------------------------------------------------------- + +# If the path to a directory or source file has been defined, +# get the type here (FILEPATH or PATH): +if (DEFINED ${use_case}_FILE_PATH) + get_path_type(${${use_case}_FILE_PATH} PATH_TYPE) + + # Set the default type if path is not a dir or file path (or undefined) + if (NOT ${PATH_TYPE} STREQUAL PATH AND NOT ${PATH_TYPE} STREQUAL FILEPATH) + message(FATAL_ERROR "Invalid ${use_case}_FILE_PATH. It should be a dir or file path.") + endif() +else() + # Default is a directory path + set(PATH_TYPE PATH) +endif() + +message(STATUS "${use_case}_FILE_PATH is of type: ${PATH_TYPE}") + +USER_OPTION(${use_case}_FILE_PATH "Directory with custom WAV input files, or path to a single WAV file, to use in the evaluation application." 
+ ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/samples/ + ${PATH_TYPE}) + +USER_OPTION(${use_case}_LABELS_TXT_FILE "Labels' txt file for the chosen model." + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/labels_wav2letter.txt + FILEPATH) + +USER_OPTION(${use_case}_AUDIO_RATE "Specify the target sampling rate. Default is 16000." + 16000 + STRING) + +USER_OPTION(${use_case}_AUDIO_MONO "Specify if the audio needs to be converted to mono. Default is ON." + ON + BOOL) + +USER_OPTION(${use_case}_AUDIO_OFFSET "Specify the offset to start reading after this time (in seconds). Default is 0." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_DURATION "Specify the audio duration to load (in seconds). If set to 0 the entire audio will be processed." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_RES_TYPE "Specify re-sampling algorithm to use. By default is 'kaiser_best'." + kaiser_best + STRING) + +USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples to use. By default is 16000, if the audio is shorter will be automatically padded." + 16000 + STRING) + +USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD "Specify the score threshold [0.0, 1.0) that must be applied to the inference results for a label to be deemed valid." + 0.5 + STRING) + +# Generate input files +generate_audio_code(${${use_case}_FILE_PATH} ${SRC_GEN_DIR} ${INC_GEN_DIR} + ${${use_case}_AUDIO_RATE} + ${${use_case}_AUDIO_MONO} + ${${use_case}_AUDIO_OFFSET} + ${${use_case}_AUDIO_DURATION} + ${${use_case}_AUDIO_RES_TYPE} + ${${use_case}_AUDIO_MIN_SAMPLES}) + +# Generate labels file +set(${use_case}_LABELS_CPP_FILE Labels) +generate_labels_code( + INPUT "${${use_case}_LABELS_TXT_FILE}" + DESTINATION_SRC ${SRC_GEN_DIR} + DESTINATION_HDR ${INC_GEN_DIR} + OUTPUT_FILENAME "${${use_case}_LABELS_CPP_FILE}" +) + + +USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen model" + 0x00200000 + STRING) + + +# If there is no tflite file pointed to +if (NOT DEFINED ${use_case}_MODEL_TFLITE_PATH) + + set(MODEL_FILENAME wav2letter_int8.tflite) + set(MODEL_RESOURCES_DIR ${DOWNLOAD_DEP_DIR}/${use_case}) + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR}) + set(DEFAULT_MODEL_PATH ${MODEL_RESOURCES_DIR}/${MODEL_FILENAME}) + + # Download the default model + set(ZOO_COMMON_SUBPATH "models/speech_recognition/wav2letter/tflite_int8") + set(ZOO_MODEL_SUBPATH "${ZOO_COMMON_SUBPATH}/${MODEL_FILENAME}") + + download_file_from_modelzoo(${ZOO_MODEL_SUBPATH} ${DEFAULT_MODEL_PATH}) + + if (ETHOS_U55_ENABLED) + message(STATUS + "Ethos-U55 is enabled, but the model downloaded is not optimized by vela. 
" + "To use Ethos-U55 acceleration, optimise the downloaded model and pass it " + "as ${use_case}_MODEL_TFLITE_PATH to the CMake configuration.") + endif() + + # If the target platform is native + if (${TARGET_PLATFORM} STREQUAL native) + + # Download test vectors + set(ZOO_TEST_IFM_SUBPATH "${ZOO_COMMON_SUBPATH}/testing_input/input_2_int8/0.npy") + set(ZOO_TEST_OFM_SUBPATH "${ZOO_COMMON_SUBPATH}/testing_output/Identity_int8/0.npy") + + set(${use_case}_TEST_IFM ${MODEL_RESOURCES_DIR}/ifm0.npy CACHE FILEPATH + "Input test vector for ${use_case}") + set(${use_case}_TEST_OFM ${MODEL_RESOURCES_DIR}/ofm0.npy CACHE FILEPATH + "Input test vector for ${use_case}") + + download_file_from_modelzoo(${ZOO_TEST_IFM_SUBPATH} ${${use_case}_TEST_IFM}) + download_file_from_modelzoo(${ZOO_TEST_OFM_SUBPATH} ${${use_case}_TEST_OFM}) + + set(TEST_SRC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/src) + set(TEST_INC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/include) + file(MAKE_DIRECTORY ${TEST_SRC_GEN_DIR} ${TEST_INC_GEN_DIR}) + + # Generate test data files to be included in x86 tests + generate_test_data_code( + INPUT_DIR "${DOWNLOAD_DEP_DIR}/${use_case}" + DESTINATION_SRC ${TEST_SRC_GEN_DIR} + DESTINATION_HDR ${TEST_INC_GEN_DIR} + USECASE "${use_case}") + endif() + +else() + set(DEFAULT_MODEL_PATH "N/A") +endif() + +set(EXTRA_MODEL_CODE + "/* Model parameters for ${use_case} */" + "extern const int g_FrameLength = 512" + "extern const int g_FrameStride = 160" + "extern const int g_ctxLen = 98" + "extern const float g_ScoreThreshold = ${${use_case}_MODEL_SCORE_THRESHOLD}" + ) + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH "NN models file to be used in the evaluation application. Model files must be in tflite format." + ${DEFAULT_MODEL_PATH} + FILEPATH + ) + +# Generate model file +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH} + DESTINATION ${SRC_GEN_DIR} + EXPRESSIONS ${EXTRA_MODEL_CODE} + ) diff --git a/source/use_case/img_class/include/MobileNetModel.hpp b/source/use_case/img_class/include/MobileNetModel.hpp new file mode 100644 index 0000000..f0521ce --- /dev/null +++ b/source/use_case/img_class/include/MobileNetModel.hpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IMG_CLASS_MOBILENETMODEL_HPP +#define IMG_CLASS_MOBILENETMODEL_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { + + class MobileNetModel : public Model { + + public: + /* Indices for the expected model - based on input tensor shape */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_inputChannelsIdx = 3; + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. 
*/ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int _ms_maxOpCnt = 7; + + /* A mutable op resolver instance. */ + tflite::MicroMutableOpResolver<_ms_maxOpCnt> _m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* IMG_CLASS_MOBILENETMODEL_HPP */ \ No newline at end of file diff --git a/source/use_case/img_class/include/UseCaseHandler.hpp b/source/use_case/img_class/include/UseCaseHandler.hpp new file mode 100644 index 0000000..a6cf104 --- /dev/null +++ b/source/use_case/img_class/include/UseCaseHandler.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IMG_CLASS_EVT_HANDLER_HPP +#define IMG_CLASS_EVT_HANDLER_HPP + +#include "AppContext.hpp" + +namespace arm { +namespace app { + + /** + * @brief Handles the inference event. + * @param[in] ctx Pointer to the application context. + * @param[in] imgIndex Index to the image to classify. + * @param[in] runAll Flag to request classification of all the available images. + * @return true or false based on execution success. + **/ + bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* IMG_CLASS_EVT_HANDLER_HPP */ \ No newline at end of file diff --git a/source/use_case/img_class/src/MainLoop.cc b/source/use_case/img_class/src/MainLoop.cc new file mode 100644 index 0000000..469907c --- /dev/null +++ b/source/use_case/img_class/src/MainLoop.cc @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* Brings in platform definitions. */ +#include "Classifier.hpp" /* Classifier. */ +#include "InputFiles.hpp" /* For input images. */ +#include "Labels.hpp" /* For label strings. */ +#include "MobileNetModel.hpp" /* Model class for running inference. */ +#include "UseCaseHandler.hpp" /* Handlers for different user options. */ +#include "UseCaseCommonUtils.hpp" /* Utils functions. */ + +using ImgClassClassifier = arm::app::Classifier; + +enum opcodes +{ + MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ + MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index. 
*/ + MENU_OPT_RUN_INF_ALL, /* Run inference on all. */ + MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ + MENU_OPT_LIST_IMAGES /* List the current baked images. */ +}; + +static void DisplayMenu() +{ + printf("\n\nUser input required\n"); + printf("Enter option number from:\n\n"); + printf(" %u. Classify next image\n", MENU_OPT_RUN_INF_NEXT); + printf(" %u. Classify image at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); + printf(" %u. Run classification on all images\n", MENU_OPT_RUN_INF_ALL); + printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); + printf(" %u. List images\n\n", MENU_OPT_LIST_IMAGES); + printf(" Choice: "); +} + +void main_loop(hal_platform& platform) +{ + arm::app::MobileNetModel model; /* Model wrapper object. */ + + /* Load the model. */ + if (!model.Init()) { + printf_err("Failed to initialise model\n"); + return; + } + + /* Instantiate application context. */ + arm::app::ApplicationContext caseContext; + + caseContext.Set("platform", platform); + caseContext.Set("model", model); + caseContext.Set("imgIndex", 0); + + ImgClassClassifier classifier; /* Classifier wrapper object. */ + caseContext.Set("classifier", classifier); + + std::vector labels; + GetLabelsVector(labels); + caseContext.Set&>("labels", labels); + + /* Loop. */ + bool executionSuccessful = true; + constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; + + /* Loop. */ + do { + int menuOption = MENU_OPT_RUN_INF_NEXT; + if (bUseMenu) { + DisplayMenu(); + menuOption = arm::app::ReadUserInputAsInt(platform); + printf("\n"); + } + switch (menuOption) { + case MENU_OPT_RUN_INF_NEXT: + executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get("imgIndex"), false); + break; + case MENU_OPT_RUN_INF_CHOSEN: { + printf(" Enter the image index [0, %d]: ", NUMBER_OF_FILES-1); + auto imgIndex = static_cast(arm::app::ReadUserInputAsInt(platform)); + executionSuccessful = ClassifyImageHandler(caseContext, imgIndex, false); + break; + } + case MENU_OPT_RUN_INF_ALL: + executionSuccessful = ClassifyImageHandler(caseContext, caseContext.Get("imgIndex"), true); + break; + case MENU_OPT_SHOW_MODEL_INFO: + executionSuccessful = model.ShowModelInfoHandler(); + break; + case MENU_OPT_LIST_IMAGES: + executionSuccessful = ListFilesHandler(caseContext); + break; + default: + printf("Incorrect choice, try again."); + break; + } + } while (executionSuccessful && bUseMenu); + info("Main loop terminated.\n"); +} \ No newline at end of file diff --git a/source/use_case/img_class/src/MobileNetModel.cc b/source/use_case/img_class/src/MobileNetModel.cc new file mode 100644 index 0000000..eeaa109 --- /dev/null +++ b/source/use_case/img_class/src/MobileNetModel.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "MobileNetModel.hpp" + +#include "hal.h" + +const tflite::MicroOpResolver& arm::app::MobileNetModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +bool arm::app::MobileNetModel::EnlistOperations() +{ + this->_m_opResolver.AddDepthwiseConv2D(); + this->_m_opResolver.AddConv2D(); + this->_m_opResolver.AddAveragePool2D(); + this->_m_opResolver.AddAdd(); + this->_m_opResolver.AddReshape(); + this->_m_opResolver.AddSoftmax(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->_m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::MobileNetModel::ModelPointer() +{ + return GetModelPointer(); +} + +extern size_t GetModelLen(); +size_t arm::app::MobileNetModel::ModelSize() +{ + return GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc new file mode 100644 index 0000000..a412fec --- /dev/null +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "UseCaseHandler.hpp" + +#include "Classifier.hpp" +#include "InputFiles.hpp" +#include "MobileNetModel.hpp" +#include "UseCaseCommonUtils.hpp" +#include "hal.h" + +using ImgClassClassifier = arm::app::Classifier; + +namespace arm { +namespace app { + + /** + * @brief Helper function to load the current image into the input + * tensor. + * @param[in] imIdx Image index (from the pool of images available + * to the application). + * @param[out] inputTensor Pointer to the input tensor to be populated. + * @return true if tensor is loaded, false otherwise. + **/ + static bool _LoadImageIntoTensor(uint32_t imIdx, TfLiteTensor* inputTensor); + + /** + * @brief Helper function to increment current image index. + * @param[in,out] ctx Pointer to the application context object. + **/ + static void _IncrementAppCtxImageIdx(ApplicationContext& ctx); + + /** + * @brief Helper function to set the image index. + * @param[in,out] ctx Pointer to the application context object. + * @param[in] idx Value to be set. + * @return true if index is set, false otherwise. + **/ + static bool _SetAppCtxImageIdx(ApplicationContext& ctx, uint32_t idx); + + /** + * @brief Presents inference results using the data presentation + * object. + * @param[in] platform Reference to the hal platform object. + * @param[in] results Vector of classification results to be displayed. + * @param[in] infTimeMs Inference time in milliseconds, if available + * otherwise, this can be passed in as 0. + * @return true if successful, false otherwise. 
+ **/ + static bool _PresentInferenceResult(hal_platform& platform, + const std::vector& results); + + /** + * @brief Helper function to convert a UINT8 image to INT8 format. + * @param[in,out] data Pointer to the data start. + * @param[in] kMaxImageSize Total number of pixels in the image. + **/ + static void ConvertImgToInt8(void* data, size_t kMaxImageSize); + + /* Image inference classification handler. */ + bool ClassifyImageHandler(ApplicationContext& ctx, uint32_t imgIndex, bool runAll) + { + auto& platform = ctx.Get("platform"); + + constexpr uint32_t dataPsnImgDownscaleFactor = 2; + constexpr uint32_t dataPsnImgStartX = 10; + constexpr uint32_t dataPsnImgStartY = 35; + + constexpr uint32_t dataPsnTxtInfStartX = 150; + constexpr uint32_t dataPsnTxtInfStartY = 40; + + platform.data_psn->clear(COLOR_BLACK); + + auto& model = ctx.Get("model"); + + /* If the request has a valid size, set the image index. */ + if (imgIndex < NUMBER_OF_FILES) { + if (!_SetAppCtxImageIdx(ctx, imgIndex)) { + return false; + } + } + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + auto curImIdx = ctx.Get("imgIndex"); + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); + + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < 3) { + printf_err("Input tensor dimension should be >= 3\n"); + return false; + } + + TfLiteIntArray* inputShape = model.GetInputShape(0); + + const uint32_t nCols = inputShape->data[arm::app::MobileNetModel::ms_inputColsIdx]; + const uint32_t nRows = inputShape->data[arm::app::MobileNetModel::ms_inputRowsIdx]; + const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; + + std::vector results; + + do { + /* Strings for presentation/logging. */ + std::string str_inf{"Running inference... "}; + + /* Copy over the data. */ + _LoadImageIntoTensor(ctx.Get("imgIndex"), inputTensor); + + /* Display this image on the LCD. */ + platform.data_psn->present_data_image( + (uint8_t*) inputTensor->data.data, + nCols, nRows, nChannels, + dataPsnImgStartX, dataPsnImgStartY, dataPsnImgDownscaleFactor); + + /* If the data is signed. */ + if (model.IsDataSigned()) { + ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + } + + /* Display message on the LCD - inference running. */ + platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + /* Run inference over this image. */ + info("Running inference on image %u => %s\n", ctx.Get("imgIndex"), + get_filename(ctx.Get("imgIndex"))); + + RunInference(platform, model); + + /* Erase. */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text(str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + auto& classifier = ctx.Get("classifier"); + classifier.GetClassificationResults(outputTensor, results, + ctx.Get&>("labels"), + 5); + + /* Add results to context for access outside handler. 
*/ + ctx.Set>("results", results); + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(outputTensor); +#endif /* VERIFY_TEST_OUTPUT */ + + if (!_PresentInferenceResult(platform, results)) { + return false; + } + + _IncrementAppCtxImageIdx(ctx); + + } while (runAll && ctx.Get("imgIndex") != curImIdx); + + return true; + } + + static bool _LoadImageIntoTensor(const uint32_t imIdx, TfLiteTensor* inputTensor) + { + const size_t copySz = inputTensor->bytes < IMAGE_DATA_SIZE ? + inputTensor->bytes : IMAGE_DATA_SIZE; + const uint8_t* imgSrc = get_img_array(imIdx); + if (nullptr == imgSrc) { + printf_err("Failed to get image index %u (max: %u)\n", imIdx, + NUMBER_OF_FILES - 1); + return false; + } + + memcpy(inputTensor->data.data, imgSrc, copySz); + debug("Image %u loaded\n", imIdx); + return true; + } + + static void _IncrementAppCtxImageIdx(ApplicationContext& ctx) + { + auto curImIdx = ctx.Get("imgIndex"); + + if (curImIdx + 1 >= NUMBER_OF_FILES) { + ctx.Set("imgIndex", 0); + return; + } + ++curImIdx; + ctx.Set("imgIndex", curImIdx); + } + + static bool _SetAppCtxImageIdx(ApplicationContext& ctx, const uint32_t idx) + { + if (idx >= NUMBER_OF_FILES) { + printf_err("Invalid idx %u (expected less than %u)\n", + idx, NUMBER_OF_FILES); + return false; + } + ctx.Set("imgIndex", idx); + return true; + } + + static bool _PresentInferenceResult(hal_platform& platform, + const std::vector& results) + { + constexpr uint32_t dataPsnTxtStartX1 = 150; + constexpr uint32_t dataPsnTxtStartY1 = 30; + + constexpr uint32_t dataPsnTxtStartX2 = 10; + constexpr uint32_t dataPsnTxtStartY2 = 150; + + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ + + platform.data_psn->set_text_color(COLOR_GREEN); + + /* Display each result. */ + uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; + uint32_t rowIdx2 = dataPsnTxtStartY2; + + for (uint32_t i = 0; i < results.size(); ++i) { + std::string resultStr = + std::to_string(i + 1) + ") " + + std::to_string(results[i].m_labelIdx) + + " (" + std::to_string(results[i].m_normalisedVal) + ")"; + + platform.data_psn->present_data_text( + resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX1, rowIdx1, 0); + rowIdx1 += dataPsnTxtYIncr; + + resultStr = std::to_string(i + 1) + ") " + results[i].m_label; + platform.data_psn->present_data_text( + resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX2, rowIdx2, 0); + rowIdx2 += dataPsnTxtYIncr; + + info("%u) %u (%f) -> %s\n", i, results[i].m_labelIdx, + results[i].m_normalisedVal, results[i].m_label.c_str()); + } + + return true; + } + + static void ConvertImgToInt8(void* data, const size_t kMaxImageSize) + { + auto* tmp_req_data = (uint8_t*) data; + auto* tmp_signed_req_data = (int8_t*) data; + + for (size_t i = 0; i < kMaxImageSize; i++) { + tmp_signed_req_data[i] = (int8_t) ( + (int32_t) (tmp_req_data[i]) - 128); + } + } + +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/img_class/usecase.cmake b/source/use_case/img_class/usecase.cmake new file mode 100644 index 0000000..440eabe --- /dev/null +++ b/source/use_case/img_class/usecase.cmake @@ -0,0 +1,125 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#---------------------------------------------------------------------------- + +# If the path to a directory or source file has been defined, +# get the type here (FILEPATH or PATH): +if (DEFINED ${use_case}_FILE_PATH) + get_path_type(${${use_case}_FILE_PATH} PATH_TYPE) + # Set the default type if path is not a dir or file path (or undefined) + if (NOT ${PATH_TYPE} STREQUAL PATH AND NOT ${PATH_TYPE} STREQUAL FILEPATH) + message(FATAL_ERROR "Invalid ${use_case}_FILE_PATH. It should be a dir or file path.") + endif() +else() + # Default is a directory path + set(PATH_TYPE PATH) +endif() + +message(STATUS "${use_case}_FILE_PATH is of type: ${PATH_TYPE}") + +USER_OPTION(${use_case}_FILE_PATH "Directory with custom image files to use, or path to a single image, in the evaluation application" + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/samples/ + ${PATH_TYPE}) + +USER_OPTION(${use_case}_IMAGE_SIZE "Square image size in pixels. Images will be resized to this size." + 224 + STRING) + +USER_OPTION(${use_case}_LABELS_TXT_FILE "Labels' txt file for the chosen model" + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/labels_mobilenet_v2_1.0_224.txt + FILEPATH) + +# Generate input files +generate_images_code("${${use_case}_FILE_PATH}" + ${SRC_GEN_DIR} + ${INC_GEN_DIR} + "${${use_case}_IMAGE_SIZE}") + +# Generate labels file +set(${use_case}_LABELS_CPP_FILE Labels) +generate_labels_code( + INPUT "${${use_case}_LABELS_TXT_FILE}" + DESTINATION_SRC ${SRC_GEN_DIR} + DESTINATION_HDR ${INC_GEN_DIR} + OUTPUT_FILENAME "${${use_case}_LABELS_CPP_FILE}" +) + +USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen model" + 0x00200000 + STRING) + +# If there is no tflite file pointed to +if (NOT DEFINED ${use_case}_MODEL_TFLITE_PATH) + + set(MODEL_RESOURCES_DIR ${DOWNLOAD_DEP_DIR}/${use_case}) + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR}) + set(MODEL_FILENAME mobilenet_v2_1.0_224_quantized_1_default_1.tflite) + set(DEFAULT_MODEL_PATH ${MODEL_RESOURCES_DIR}/${MODEL_FILENAME}) + + # Download the default model + set(ZOO_COMMON_SUBPATH "models/image_classification/mobilenet_v2_1.0_224/tflite_uint8") + set(ZOO_MODEL_SUBPATH "${ZOO_COMMON_SUBPATH}/${MODEL_FILENAME}") + + download_file_from_modelzoo(${ZOO_MODEL_SUBPATH} ${DEFAULT_MODEL_PATH}) + + if (ETHOS_U55_ENABLED) + message(STATUS + "Ethos-U55 is enabled, but the model downloaded is not optimized by vela. 
" + "To use Ethos-U55 acceleration, optimise the downloaded model and pass it " + "as ${use_case}_MODEL_TFLITE_PATH to the CMake configuration.") + endif() + + # If the target platform is native + if (${TARGET_PLATFORM} STREQUAL native) + + # Download test vectors + set(ZOO_TEST_IFM_SUBPATH "${ZOO_COMMON_SUBPATH}/testing_input/input/0.npy") + set(ZOO_TEST_OFM_SUBPATH "${ZOO_COMMON_SUBPATH}/testing_output/output/0.npy") + + set(${use_case}_TEST_IFM ${MODEL_RESOURCES_DIR}/ifm0.npy CACHE FILEPATH + "Input test vector for ${use_case}") + set(${use_case}_TEST_OFM ${MODEL_RESOURCES_DIR}/ofm0.npy CACHE FILEPATH + "Input test vector for ${use_case}") + + download_file_from_modelzoo(${ZOO_TEST_IFM_SUBPATH} ${${use_case}_TEST_IFM}) + download_file_from_modelzoo(${ZOO_TEST_OFM_SUBPATH} ${${use_case}_TEST_OFM}) + + set(TEST_SRC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/src) + set(TEST_INC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/include) + file(MAKE_DIRECTORY ${TEST_SRC_GEN_DIR} ${TEST_INC_GEN_DIR}) + + # Generate test data files to be included in x86 tests + generate_test_data_code( + INPUT_DIR "${DOWNLOAD_DEP_DIR}/${use_case}" + DESTINATION_SRC ${TEST_SRC_GEN_DIR} + DESTINATION_HDR ${TEST_INC_GEN_DIR} + USECASE "${use_case}") + endif() + +else() + set(DEFAULT_MODEL_PATH "N/A") +endif() + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH "NN models file to be used in the evaluation application. Model files must be in tflite format." + ${DEFAULT_MODEL_PATH} + FILEPATH + ) + +# Generate model file +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH} + DESTINATION ${SRC_GEN_DIR} + ) diff --git a/source/use_case/inference_runner/include/TestModel.hpp b/source/use_case/inference_runner/include/TestModel.hpp new file mode 100644 index 0000000..0b3e9b9 --- /dev/null +++ b/source/use_case/inference_runner/include/TestModel.hpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INF_RUNNER_TESTMODEL_HPP +#define INF_RUNNER_TESTMODEL_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { + + class TestModel : public Model { + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::AllOpsResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance, not needed as using AllOpsResolver. */ + bool EnlistOperations() override {return false;} + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + + /* No need to define individual ops at the cost of extra memory. 
*/ + tflite::AllOpsResolver _m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* INF_RUNNER_TESTMODEL_HPP */ \ No newline at end of file diff --git a/source/use_case/inference_runner/include/UseCaseHandler.hpp b/source/use_case/inference_runner/include/UseCaseHandler.hpp new file mode 100644 index 0000000..4962650 --- /dev/null +++ b/source/use_case/inference_runner/include/UseCaseHandler.hpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INF_RUNNER_EVT_HANDLER_HPP +#define INF_RUNNER_EVT_HANDLER_HPP + +#include "AppContext.hpp" + +namespace arm { +namespace app { + + /** + * @brief Handles the inference event. + * @param[in] ctx Pointer to the application context. + * @return true or false based on execution success. + **/ + bool RunInferenceHandler(ApplicationContext& ctx); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* INF_RUNNER_EVT_HANDLER_HPP */ \ No newline at end of file diff --git a/source/use_case/inference_runner/src/MainLoop.cc b/source/use_case/inference_runner/src/MainLoop.cc new file mode 100644 index 0000000..b110a24 --- /dev/null +++ b/source/use_case/inference_runner/src/MainLoop.cc @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* Brings in platform definitions. */ +#include "TestModel.hpp" /* Model class for running inference. */ +#include "UseCaseHandler.hpp" /* Handlers for different user options. */ +#include "UseCaseCommonUtils.hpp" /* Utils functions. */ + +enum opcodes +{ + MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ + MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ +}; + +void main_loop(hal_platform& platform) +{ + arm::app::TestModel model; /* Model wrapper object. */ + + /* Load the model. */ + if (!model.Init()) { + printf_err("Failed to initialise model\n"); + return; + } + + /* Instantiate application context. */ + arm::app::ApplicationContext caseContext; + + caseContext.Set("platform", platform); + caseContext.Set("model", model); + caseContext.Set("imgIndex", 0); + + /* Loop. 
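+     * Unlike the other use cases there is no interactive menu here: a single
+     * inference pass is run over the generated input data and the result reported.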
*/ + if (RunInferenceHandler(caseContext)) { + info("Inference completed.\n"); + } else { + printf_err("Inference failed.\n"); + } +} diff --git a/source/use_case/inference_runner/src/TestModel.cc b/source/use_case/inference_runner/src/TestModel.cc new file mode 100644 index 0000000..0926a96 --- /dev/null +++ b/source/use_case/inference_runner/src/TestModel.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "TestModel.hpp" + +#include "hal.h" + +const tflite::AllOpsResolver& arm::app::TestModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::TestModel::ModelPointer() +{ + return GetModelPointer(); +} + +extern size_t GetModelLen(); +size_t arm::app::TestModel::ModelSize() +{ + return GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/inference_runner/src/UseCaseHandler.cc b/source/use_case/inference_runner/src/UseCaseHandler.cc new file mode 100644 index 0000000..ac4ea47 --- /dev/null +++ b/source/use_case/inference_runner/src/UseCaseHandler.cc @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "UseCaseHandler.hpp" + +#include "TestModel.hpp" +#include "UseCaseCommonUtils.hpp" +#include "hal.h" + +#include + +namespace arm { +namespace app { + + bool RunInferenceHandler(ApplicationContext& ctx) + { + auto& platform = ctx.Get("platform"); + auto& model = ctx.Get("model"); + + constexpr uint32_t dataPsnTxtInfStartX = 150; + constexpr uint32_t dataPsnTxtInfStartY = 40; + + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + const size_t numInputs = model.GetNumInputs(); + + /* Populate each input tensor with random data. */ + for (size_t inputIndex = 0; inputIndex < numInputs; inputIndex++) { + + TfLiteTensor* inputTensor = model.GetInputTensor(inputIndex); + + debug("Populating input tensor %zu@%p\n", inputIndex, inputTensor); + debug("Total input size to be populated: %zu\n", inputTensor->bytes); + + /* Create a random input. 
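+             * Note: std::rand() is used unseeded here, so the generated byte
+             * pattern is deterministic and runs are repeatable.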
*/ + if (inputTensor->bytes > 0) { + + uint8_t* tData = tflite::GetTensorData(inputTensor); + + for (size_t j = 0; j < inputTensor->bytes; ++j) { + tData[j] = static_cast(std::rand() & 0xFF); + } + } + } + + /* Strings for presentation/logging. */ + std::string str_inf{"Running inference... "}; + + /* Display message on the LCD - inference running. */ + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + RunInference(platform, model); + + /* Erase. */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + +#if VERIFY_TEST_OUTPUT + for (size_t outputIndex = 0; outputIndex < model.GetNumOutputs(); outputIndex++) { + arm::app::DumpTensor(model.GetOutputTensor(outputIndex)); + } +#endif /* VERIFY_TEST_OUTPUT */ + + return true; + } + +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/inference_runner/usecase.cmake b/source/use_case/inference_runner/usecase.cmake new file mode 100644 index 0000000..77b1ae1 --- /dev/null +++ b/source/use_case/inference_runner/usecase.cmake @@ -0,0 +1,57 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#---------------------------------------------------------------------------- + +USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen model" + 0x00200000 + STRING) + +generate_default_input_code(${INC_GEN_DIR}) + +# If there is no tflite file pointed to +if (NOT DEFINED ${use_case}_MODEL_TFLITE_PATH) + + set(MODEL_RESOURCES_DIR ${DOWNLOAD_DEP_DIR}/${use_case}) + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR}) + set(MODEL_FILENAME dnn_s_quantized.tflite) + set(DEFAULT_MODEL_PATH ${MODEL_RESOURCES_DIR}/${MODEL_FILENAME}) + + # Download the default model + set(ZOO_COMMON_SUBPATH "models/keyword_spotting/dnn_small/tflite_int8/") + set(ZOO_MODEL_SUBPATH "${ZOO_COMMON_SUBPATH}/${MODEL_FILENAME}") + + download_file_from_modelzoo(${ZOO_MODEL_SUBPATH} ${DEFAULT_MODEL_PATH}) + + if (ETHOS_U55_ENABLED) + message(STATUS + "Ethos-U55 is enabled, but the model downloaded is not optimized by vela. " + "To use Ethos-U55 acceleration, optimise the downloaded model and pass it " + "as ${use_case}_MODEL_TFLITE_PATH to the CMake configuration.") + endif() + +else() + set(DEFAULT_MODEL_PATH "N/A") +endif() + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH "NN models file to be used in the evaluation application. Model files must be in tflite format." 
+ ${DEFAULT_MODEL_PATH} + FILEPATH) + +# Generate model file +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH} + DESTINATION ${SRC_GEN_DIR} +) diff --git a/source/use_case/kws/include/DsCnnMfcc.hpp b/source/use_case/kws/include/DsCnnMfcc.hpp new file mode 100644 index 0000000..3f681af --- /dev/null +++ b/source/use_case/kws/include/DsCnnMfcc.hpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_DSCNN_MFCC_HPP +#define KWS_DSCNN_MFCC_HPP + +#include "Mfcc.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide DS-CNN specific MFCC calculation requirements. */ + class DsCnnMFCC : public MFCC { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 40; + static constexpr uint32_t ms_defaultMelLoFreq = 20; + static constexpr uint32_t ms_defaultMelHiFreq = 4000; + static constexpr bool ms_defaultUseHtkMethod = true; + + explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen) + : MFCC(MfccParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + numFeats, frameLen, ms_defaultUseHtkMethod)) + {} + DsCnnMFCC() = delete; + ~DsCnnMFCC() = default; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_DSCNN_MFCC_HPP */ \ No newline at end of file diff --git a/source/use_case/kws/include/DsCnnModel.hpp b/source/use_case/kws/include/DsCnnModel.hpp new file mode 100644 index 0000000..a4e7110 --- /dev/null +++ b/source/use_case/kws/include/DsCnnModel.hpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_DSCNNMODEL_HPP +#define KWS_DSCNNMODEL_HPP + +#include "Model.hpp" + +extern const int g_FrameLength; +extern const int g_FrameStride; +extern const float g_ScoreThreshold; + +namespace arm { +namespace app { + + class DsCnnModel : public Model { + public: + /* Indices for the expected model - based on input and output tensor shapes */ + static constexpr uint32_t ms_inputRowsIdx = 2; + static constexpr uint32_t ms_inputColsIdx = 3; + static constexpr uint32_t ms_outputRowsIdx = 2; + static constexpr uint32_t ms_outputColsIdx = 3; + + protected: + /** @brief Gets the reference to op resolver interface class. 
*/ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. */ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int _ms_maxOpCnt = 8; + + /* A mutable op resolver instance. */ + tflite::MicroMutableOpResolver<_ms_maxOpCnt> _m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_DSCNNMODEL_HPP */ diff --git a/source/use_case/kws/include/KwsResult.hpp b/source/use_case/kws/include/KwsResult.hpp new file mode 100644 index 0000000..5a26ce1 --- /dev/null +++ b/source/use_case/kws/include/KwsResult.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_RESULT_HPP +#define KWS_RESULT_HPP + +#include "ClassificationResult.hpp" + +#include + +namespace arm { +namespace app { +namespace kws { + + using ResultVec = std::vector < arm::app::ClassificationResult >; + + /* Structure for holding kws result. */ + class KwsResult { + + public: + ResultVec m_resultVec; /* Container for "thresholded" classification results. */ + float m_timeStamp; /* Audio timestamp for this result. */ + uint32_t m_inferenceNumber; /* Corresponding inference number. */ + float m_threshold; /* Threshold value for `m_resultVec`. */ + + KwsResult() = delete; + KwsResult(ResultVec& resultVec, + const float timestamp, + const uint32_t inferenceIdx, + const float scoreThreshold) { + + this->m_threshold = scoreThreshold; + this->m_timeStamp = timestamp; + this->m_inferenceNumber = inferenceIdx; + + this->m_resultVec = ResultVec(); + for (auto & i : resultVec) { + if (i.m_normalisedVal >= this->m_threshold) { + this->m_resultVec.emplace_back(i); + } + } + } + ~KwsResult() = default; + }; + +} /* namespace kws */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_RESULT_HPP */ \ No newline at end of file diff --git a/source/use_case/kws/include/UseCaseHandler.hpp b/source/use_case/kws/include/UseCaseHandler.hpp new file mode 100644 index 0000000..1eb742f --- /dev/null +++ b/source/use_case/kws/include/UseCaseHandler.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef KWS_EVT_HANDLER_HPP +#define KWS_EVT_HANDLER_HPP + +#include "AppContext.hpp" + +namespace arm { +namespace app { + + /** + * @brief Handles the inference event. + * @param[in] ctx Pointer to the application context. + * @param[in] clipIndex Index to the audio clip to classify. + * @param[in] runAll Flag to request classification of all the available audio clips. + * @return true or false based on execution success. + **/ + bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_EVT_HANDLER_HPP */ \ No newline at end of file diff --git a/source/use_case/kws/src/DsCnnModel.cc b/source/use_case/kws/src/DsCnnModel.cc new file mode 100644 index 0000000..a093eb4 --- /dev/null +++ b/source/use_case/kws/src/DsCnnModel.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "DsCnnModel.hpp" + +#include "hal.h" + +const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +bool arm::app::DsCnnModel::EnlistOperations() +{ + this->_m_opResolver.AddReshape(); + this->_m_opResolver.AddAveragePool2D(); + this->_m_opResolver.AddConv2D(); + this->_m_opResolver.AddDepthwiseConv2D(); + this->_m_opResolver.AddFullyConnected(); + this->_m_opResolver.AddRelu(); + this->_m_opResolver.AddSoftmax(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->_m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::DsCnnModel::ModelPointer() +{ + return GetModelPointer(); +} + +extern size_t GetModelLen(); +size_t arm::app::DsCnnModel::ModelSize() +{ + return GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc new file mode 100644 index 0000000..24cb939 --- /dev/null +++ b/source/use_case/kws/src/MainLoop.cc @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "InputFiles.hpp" /* For input audio clips. */ +#include "Classifier.hpp" /* Classifier. 
*/ +#include "DsCnnModel.hpp" /* Model class for running inference. */ +#include "hal.h" /* Brings in platform definitions. */ +#include "Labels.hpp" /* For label strings. */ +#include "UseCaseHandler.hpp" /* Handlers for different user options. */ +#include "UseCaseCommonUtils.hpp" /* Utils functions. */ + +using KwsClassifier = arm::app::Classifier; + +enum opcodes +{ + MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ + MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index. */ + MENU_OPT_RUN_INF_ALL, /* Run inference on all. */ + MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ + MENU_OPT_LIST_AUDIO_CLIPS /* List the current baked audio clips. */ +}; + +static void DisplayMenu() +{ + printf("\n\nUser input required\n"); + printf("Enter option number from:\n\n"); + printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT); + printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); + printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL); + printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); + printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS); + printf(" Choice: "); +} + +void main_loop(hal_platform& platform) +{ + arm::app::DsCnnModel model; /* Model wrapper object. */ + + /* Load the model. */ + if (!model.Init()) { + printf_err("Failed to initialise model\n"); + return; + } + + /* Instantiate application context. */ + arm::app::ApplicationContext caseContext; + + caseContext.Set("platform", platform); + caseContext.Set("model", model); + caseContext.Set("clipIndex", 0); + caseContext.Set("frameLength", g_FrameLength); + caseContext.Set("frameStride", g_FrameStride); + caseContext.Set("scoreThreshold", g_ScoreThreshold); /* Normalised score threshold. */ + + KwsClassifier classifier; /* classifier wrapper object. */ + caseContext.Set("classifier", classifier); + + std::vector labels; + GetLabelsVector(labels); + + caseContext.Set&>("labels", labels); + + bool executionSuccessful = true; + constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; + + /* Loop. */ + do { + int menuOption = MENU_OPT_RUN_INF_NEXT; + if (bUseMenu) { + DisplayMenu(); + menuOption = arm::app::ReadUserInputAsInt(platform); + printf("\n"); + } + switch (menuOption) { + case MENU_OPT_RUN_INF_NEXT: + executionSuccessful = ClassifyAudioHandler(caseContext, caseContext.Get("clipIndex"), false); + break; + case MENU_OPT_RUN_INF_CHOSEN: { + printf(" Enter the audio clip index [0, %d]: ", NUMBER_OF_FILES-1); + auto clipIndex = static_cast(arm::app::ReadUserInputAsInt(platform)); + executionSuccessful = ClassifyAudioHandler(caseContext, clipIndex, false); + break; + } + case MENU_OPT_RUN_INF_ALL: + executionSuccessful = ClassifyAudioHandler(caseContext,caseContext.Get("clipIndex"), true); + break; + case MENU_OPT_SHOW_MODEL_INFO: + executionSuccessful = model.ShowModelInfoHandler(); + break; + case MENU_OPT_LIST_AUDIO_CLIPS: + executionSuccessful = ListFilesHandler(caseContext); + break; + default: + printf("Incorrect choice, try again."); + break; + } + } while (executionSuccessful && bUseMenu); + info("Main loop terminated.\n"); +} \ No newline at end of file diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc new file mode 100644 index 0000000..872d323 --- /dev/null +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -0,0 +1,452 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "UseCaseHandler.hpp" + +#include "InputFiles.hpp" +#include "Classifier.hpp" +#include "DsCnnModel.hpp" +#include "hal.h" +#include "DsCnnMfcc.hpp" +#include "AudioUtils.hpp" +#include "UseCaseCommonUtils.hpp" +#include "KwsResult.hpp" + +#include +#include + +using KwsClassifier = arm::app::Classifier; + +namespace arm { +namespace app { + + /** + * @brief Helper function to increment current audio clip index. + * @param[in,out] ctx Pointer to the application context object. + **/ + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx); + + /** + * @brief Helper function to set the audio clip index. + * @param[in,out] ctx Pointer to the application context object. + * @param[in] idx Value to be set. + * @return true if index is set, false otherwise. + **/ + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); + + /** + * @brief Presents inference results using the data presentation + * object. + * @param[in] platform Reference to the hal platform object. + * @param[in] results Vector of classification results to be displayed. + * @param[in] infTimeMs Inference time in milliseconds, if available, + * otherwise, this can be passed in as 0. + * @return true if successful, false otherwise. + **/ + static bool _PresentInferenceResult(hal_platform& platform, + const std::vector& results); + + /** + * @brief Returns a function to perform feature calculation and populates input tensor data with + * MFCC data. + * + * Input tensor data type check is performed to choose correct MFCC feature data type. + * If tensor has an integer data type then original features are quantised. + * + * Warning: MFCC calculator provided as input must have the same life scope as returned function. + * + * @param[in] mfcc MFCC feature calculator. + * @param[in,out] inputTensor Input tensor pointer to store calculated features. + * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). + * @return Function to be called providing audio sample and sliding window index. + */ + static std::function&, int, bool, size_t)> + GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + TfLiteTensor* inputTensor, + size_t cacheSize); + + /* Audio inference handler. */ + bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) + { + auto& platform = ctx.Get("platform"); + + constexpr uint32_t dataPsnTxtInfStartX = 20; + constexpr uint32_t dataPsnTxtInfStartY = 40; + constexpr int minTensorDims = static_cast( + (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? + arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + + platform.data_psn->clear(COLOR_BLACK); + + auto& model = ctx.Get("model"); + + /* If the request has a valid size, set the audio index. 
*/ + if (clipIndex < NUMBER_OF_FILES) { + if (!_SetAppCtxClipIdx(ctx, clipIndex)) { + return false; + } + } + if (!model.IsInited()) { + printf_err("Model is not initialised! Terminating processing.\n"); + return false; + } + + const auto frameLength = ctx.Get("frameLength"); + const auto frameStride = ctx.Get("frameStride"); + const auto scoreThreshold = ctx.Get("scoreThreshold"); + auto startClipIdx = ctx.Get("clipIndex"); + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + TfLiteTensor* inputTensor = model.GetInputTensor(0); + + if (!inputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return false; + } else if (inputTensor->dims->size < minTensorDims) { + printf_err("Input tensor dimension should be >= %d\n", minTensorDims); + return false; + } + + TfLiteIntArray* inputShape = model.GetInputShape(0); + const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx]; + const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx]; + + audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength); + mfcc.Init(); + + /* Deduce the data length required for 1 inference from the network parameters. */ + auto audioDataWindowSize = kNumRows * frameStride + (frameLength - frameStride); + auto mfccWindowSize = frameLength; + auto mfccWindowStride = frameStride; + + /* We choose to move by half the window size => for a 1 second window size + * there is an overlap of 0.5 seconds. */ + auto audioDataStride = audioDataWindowSize / 2; + + /* To have the previously calculated features re-usable, stride must be multiple + * of MFCC features window stride. */ + if (0 != audioDataStride % mfccWindowStride) { + + /* Reduce the stride. */ + audioDataStride -= audioDataStride % mfccWindowStride; + } + + auto nMfccVectorsInAudioStride = audioDataStride/mfccWindowStride; + + /* We expect to be sampling 1 second worth of data at a time. + * NOTE: This is only used for time stamp calculation. */ + const float secondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + + do { + auto currentIndex = ctx.Get("clipIndex"); + + /* Creating a mfcc features sliding window for the data required for 1 inference. */ + auto audioMFCCWindowSlider = audio::SlidingWindow( + get_audio_array(currentIndex), + audioDataWindowSize, mfccWindowSize, + mfccWindowStride); + + /* Creating a sliding window through the whole audio clip. */ + auto audioDataSlider = audio::SlidingWindow( + get_audio_array(currentIndex), + get_audio_array_size(currentIndex), + audioDataWindowSize, audioDataStride); + + /* Calculate number of the feature vectors in the window overlap region. + * These feature vectors will be reused.*/ + auto numberOfReusedFeatureVectors = audioMFCCWindowSlider.TotalStrides() + 1 + - nMfccVectorsInAudioStride; + + /* Construct feature calculation function. */ + auto mfccFeatureCalc = GetFeatureCalculator(mfcc, inputTensor, + numberOfReusedFeatureVectors); + + if (!mfccFeatureCalc){ + return false; + } + + /* Declare a container for results. */ + std::vector results; + + /* Display message on the LCD - inference running. */ + std::string str_inf{"Running inference... "}; + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + info("Running inference on audio clip %u => %s\n", currentIndex, + get_filename(currentIndex)); + + /* Start sliding through audio clip. 
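+             * Each stride of the outer slider produces one inference; the inner MFCC
+             * slider fills the input tensor, reusing cached feature vectors for the
+             * overlap between consecutive windows.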
*/ + while (audioDataSlider.HasNext()) { + const int16_t *inferenceWindow = audioDataSlider.Next(); + + /* We moved to the next window - set the features sliding to the new address. */ + audioMFCCWindowSlider.Reset(inferenceWindow); + + /* The first window does not have cache ready. */ + bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; + + /* Start calculating features inside one audio sliding window. */ + while (audioMFCCWindowSlider.HasNext()) { + const int16_t *mfccWindow = audioMFCCWindowSlider.Next(); + std::vector mfccAudioData = std::vector(mfccWindow, + mfccWindow + mfccWindowSize); + /* Compute features for this window and write them to input tensor. */ + mfccFeatureCalc(mfccAudioData, + audioMFCCWindowSlider.Index(), + useCache, + nMfccVectorsInAudioStride); + } + + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); + + /* Run inference over this audio clip sliding window. */ + arm::app::RunInference(platform, model); + + std::vector classificationResult; + auto& classifier = ctx.Get("classifier"); + classifier.GetClassificationResults(outputTensor, classificationResult, + ctx.Get&>("labels"), 1); + + results.emplace_back(kws::KwsResult(classificationResult, + audioDataSlider.Index() * secondsPerSample * audioDataStride, + audioDataSlider.Index(), scoreThreshold)); + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(outputTensor); +#endif /* VERIFY_TEST_OUTPUT */ + } /* while (audioDataSlider.HasNext()) */ + + /* Erase. */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + + ctx.Set>("results", results); + + if (!_PresentInferenceResult(platform, results)) { + return false; + } + + _IncrementAppCtxClipIdx(ctx); + + } while (runAll && ctx.Get("clipIndex") != startClipIdx); + + return true; + } + + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx) + { + auto curAudioIdx = ctx.Get("clipIndex"); + + if (curAudioIdx + 1 >= NUMBER_OF_FILES) { + ctx.Set("clipIndex", 0); + return; + } + ++curAudioIdx; + ctx.Set("clipIndex", curAudioIdx); + } + + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx) + { + if (idx >= NUMBER_OF_FILES) { + printf_err("Invalid idx %u (expected less than %u)\n", + idx, NUMBER_OF_FILES); + return false; + } + ctx.Set("clipIndex", idx); + return true; + } + + static bool _PresentInferenceResult(hal_platform& platform, + const std::vector& results) + { + constexpr uint32_t dataPsnTxtStartX1 = 20; + constexpr uint32_t dataPsnTxtStartY1 = 30; + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. 
*/
+
+        platform.data_psn->set_text_color(COLOR_GREEN);
+
+        /* Display each result. */
+        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
+
+        for (uint32_t i = 0; i < results.size(); ++i) {
+
+            std::string topKeyword{"<none>"};
+            float score = 0.f;
+
+            if (results[i].m_resultVec.size()) {
+                topKeyword = results[i].m_resultVec[0].m_label;
+                score = results[i].m_resultVec[0].m_normalisedVal;
+            }
+
+            std::string resultStr =
+                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
+                std::string{"s: "} + topKeyword + std::string{" ("} +
+                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
+
+            platform.data_psn->present_data_text(
+                resultStr.c_str(), resultStr.size(),
+                dataPsnTxtStartX1, rowIdx1, false);
+            rowIdx1 += dataPsnTxtYIncr;
+
+            info("For timestamp: %f (inference #: %u); threshold: %f\n",
+                 results[i].m_timeStamp, results[i].m_inferenceNumber,
+                 results[i].m_threshold);
+            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
+                info("\t\tlabel @ %u: %s, score: %f\n", j,
+                     results[i].m_resultVec[j].m_label.c_str(),
+                     results[i].m_resultVec[j].m_normalisedVal);
+            }
+        }
+
+        return true;
+    }
+
+    /**
+     * @brief Generic feature calculator factory.
+     *
+     * Returns a lambda function that computes features using a feature cache.
+     * The actual feature maths is done by a lambda function provided as a parameter.
+     * Features are written to the input tensor memory.
+     *
+     * @tparam T            Feature vector type.
+     * @param inputTensor   Model input tensor pointer.
+     * @param cacheSize     Number of feature vectors to cache. Defined by the sliding window overlap.
+     * @param compute       Features calculator function.
+     * @return Lambda function to compute features.
+     */
+    template <class T>
+    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
+    _FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
+                 std::function<std::vector<T> (std::vector<int16_t>&)> compute)
+    {
+        /* Feature cache to be captured by the lambda function. */
+        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
+
+        return [=](std::vector<int16_t>& audioDataWindow,
+                   size_t index,
+                   bool useCache,
+                   size_t featuresOverlapIndex)
+        {
+            T* tensorData = tflite::GetTensorData<T>(inputTensor);
+            std::vector<T> features;
+
+            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
+             * The overlap is at the beginning of the sliding window, with the size of the feature cache. */
+            if (useCache && index < featureCache.size()) {
+                features = std::move(featureCache[index]);
+            } else {
+                features = std::move(compute(audioDataWindow));
+            }
+            auto size = features.size();
+            auto sizeBytes = sizeof(T) * size;
+            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
+
+            /* Start renewing the cache as soon as the iteration goes out of the windows overlap.
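+             *
+             * Put differently: for a cache of C entries and an overlap starting at
+             * featuresOverlapIndex = F, indices [0, C) of a non-first window are
+             * served from the cache (when useCache is set), while indices
+             * [F, F + C) recompute their features and refill cache slots [0, C)
+             * ready for the next window.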
*/ + if (index >= featuresOverlapIndex) { + featureCache[index - featuresOverlapIndex] = std::move(features); + } + }; + } + + template std::function&, size_t , bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector& )> compute); + + template std::function&, size_t , bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector& )> compute); + + template std::function&, size_t , bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector& )> compute); + + template std::function&, size_t, bool, size_t)> + _FeatureCalc(TfLiteTensor *inputTensor, + size_t cacheSize, + std::function(std::vector&)> compute); + + + static std::function&, int, bool, size_t)> + GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + { + std::function&, size_t, bool, size_t)> mfccFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + + auto *quantParams = (TfLiteAffineQuantization *) quant.params; + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &mfcc](std::vector& audioDataWindow) { + return mfcc.MfccComputeQuant(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + case kTfLiteUInt8: { + mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &mfcc](std::vector& audioDataWindow) { + return mfcc.MfccComputeQuant(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + case kTfLiteInt16: { + mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &mfcc](std::vector& audioDataWindow) { + return mfcc.MfccComputeQuant(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + + + } else { + mfccFeatureCalc = mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [&mfcc](std::vector& audioDataWindow) { + return mfcc.MfccCompute(audioDataWindow); + }); + } + return mfccFeatureCalc; + } + +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws/usecase.cmake b/source/use_case/kws/usecase.cmake new file mode 100644 index 0000000..b5ac09e --- /dev/null +++ b/source/use_case/kws/usecase.cmake @@ -0,0 +1,159 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
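+#
+# Note: the USER_OPTION values declared below can be overridden when configuring
+# the build. A purely illustrative example (paths and values are hypothetical):
+#
+#   cmake .. -DTARGET_PLATFORM=mps3 \
+#            -Dkws_FILE_PATH=/path/to/custom/wav_files \
+#            -Dkws_MODEL_SCORE_THRESHOLD=0.8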
+#---------------------------------------------------------------------------- + +# If the path to a directory or source file has been defined, +# get the type here (FILEPATH or PATH): +if (DEFINED ${use_case}_FILE_PATH) + get_path_type(${${use_case}_FILE_PATH} PATH_TYPE) + + # Set the default type if path is not a dir or file path (or undefined) + if (NOT ${PATH_TYPE} STREQUAL PATH AND NOT ${PATH_TYPE} STREQUAL FILEPATH) + message(FATAL_ERROR "Invalid ${use_case}_FILE_PATH. It should be a dir or file path.") + endif() +else() + # Default is a directory path + set(PATH_TYPE PATH) +endif() + +message(STATUS "${use_case}_FILE_PATH is of type: ${PATH_TYPE}") +USER_OPTION(${use_case}_FILE_PATH "Directory with custom WAV input files, or path to a single WAV file, to use in the evaluation application." + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/samples/ + ${PATH_TYPE}) + +USER_OPTION(${use_case}_LABELS_TXT_FILE "Labels' txt file for the chosen model." + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/ds_cnn_labels.txt + FILEPATH) + +USER_OPTION(${use_case}_AUDIO_RATE "Specify the target sampling rate. Default is 16000." + 16000 + STRING) + +USER_OPTION(${use_case}_AUDIO_MONO "Specify if the audio needs to be converted to mono. Default is ON." + ON + BOOL) + +USER_OPTION(${use_case}_AUDIO_OFFSET "Specify the offset to start reading after this time (in seconds). Default is 0." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_DURATION "Specify the audio duration to load (in seconds). If set to 0 the entire audio will be processed." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_RES_TYPE "Specify re-sampling algorithm to use. By default is 'kaiser_best'." + kaiser_best + STRING) + +USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples to use. By default is 16000, if the audio is shorter will be automatically padded." + 16000 + STRING) + +USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD "Specify the score threshold [0.0, 1.0) that must be applied to the inference results for a label to be deemed valid." + 0.9 + STRING) + +# Generate input files +generate_audio_code(${${use_case}_FILE_PATH} ${SRC_GEN_DIR} ${INC_GEN_DIR} + ${${use_case}_AUDIO_RATE} + ${${use_case}_AUDIO_MONO} + ${${use_case}_AUDIO_OFFSET} + ${${use_case}_AUDIO_DURATION} + ${${use_case}_AUDIO_RES_TYPE} + ${${use_case}_AUDIO_MIN_SAMPLES}) + +# Generate labels file +set(${use_case}_LABELS_CPP_FILE Labels) +generate_labels_code( + INPUT "${${use_case}_LABELS_TXT_FILE}" + DESTINATION_SRC ${SRC_GEN_DIR} + DESTINATION_HDR ${INC_GEN_DIR} + OUTPUT_FILENAME "${${use_case}_LABELS_CPP_FILE}" +) + +USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen model" + 0x00100000 + STRING) + +# If there is no tflite file pointed to +if (NOT DEFINED ${use_case}_MODEL_TFLITE_PATH) + + set(MODEL_FILENAME ds_cnn_clustered_int8.tflite) + set(MODEL_RESOURCES_DIR ${DOWNLOAD_DEP_DIR}/${use_case}) + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR}) + set(DEFAULT_MODEL_PATH ${MODEL_RESOURCES_DIR}/${MODEL_FILENAME}) + + # Download the default model + set(ZOO_COMMON_SUBPATH "models/keyword_spotting/ds_cnn_large/tflite_clustered_int8") + set(ZOO_MODEL_SUBPATH "${ZOO_COMMON_SUBPATH}/${MODEL_FILENAME}") + + download_file_from_modelzoo(${ZOO_MODEL_SUBPATH} ${DEFAULT_MODEL_PATH}) + + if (ETHOS_U55_ENABLED) + message(STATUS + "Ethos-U55 is enabled, but the model downloaded is not optimized by vela. 
" + "To use Ethos-U55 acceleration, optimise the downloaded model and pass it " + "as ${use_case}_MODEL_TFLITE_PATH to the CMake configuration.") + endif() + + # If the target platform is native + if (${TARGET_PLATFORM} STREQUAL native) + + # Download test vectors + set(ZOO_TEST_IFM_SUBPATH "${ZOO_COMMON_SUBPATH}/testing_input/input_2/0.npy") + set(ZOO_TEST_OFM_SUBPATH "${ZOO_COMMON_SUBPATH}/testing_output/Identity/0.npy") + + set(${use_case}_TEST_IFM ${MODEL_RESOURCES_DIR}/ifm0.npy CACHE FILEPATH + "Input test vector for ${use_case}") + set(${use_case}_TEST_OFM ${MODEL_RESOURCES_DIR}/ofm0.npy CACHE FILEPATH + "Input test vector for ${use_case}") + + download_file_from_modelzoo(${ZOO_TEST_IFM_SUBPATH} ${${use_case}_TEST_IFM}) + download_file_from_modelzoo(${ZOO_TEST_OFM_SUBPATH} ${${use_case}_TEST_OFM}) + + set(TEST_SRC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/src) + set(TEST_INC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/include) + file(MAKE_DIRECTORY ${TEST_SRC_GEN_DIR} ${TEST_INC_GEN_DIR}) + + # Generate test data files to be included in x86 tests + generate_test_data_code( + INPUT_DIR "${DOWNLOAD_DEP_DIR}/${use_case}" + DESTINATION_SRC ${TEST_SRC_GEN_DIR} + DESTINATION_HDR ${TEST_INC_GEN_DIR} + USECASE "${use_case}") + endif() + +else() + set(DEFAULT_MODEL_PATH "N/A") +endif() + +set(EXTRA_MODEL_CODE + "/* Model parameters for ${use_case} */" + "extern const int g_FrameLength = 640" + "extern const int g_FrameStride = 320" + "extern const float g_ScoreThreshold = ${${use_case}_MODEL_SCORE_THRESHOLD}" + ) + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH "NN models file to be used in the evaluation application. Model files must be in tflite format." + ${DEFAULT_MODEL_PATH} + FILEPATH) + +# Generate model file +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH} + DESTINATION ${SRC_GEN_DIR} + EXPRESSIONS ${EXTRA_MODEL_CODE} +) diff --git a/source/use_case/kws_asr/include/AsrClassifier.hpp b/source/use_case/kws_asr/include/AsrClassifier.hpp new file mode 100644 index 0000000..de18aa8 --- /dev/null +++ b/source/use_case/kws_asr/include/AsrClassifier.hpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_CLASSIFIER_HPP +#define ASR_CLASSIFIER_HPP + +#include "Classifier.hpp" + +namespace arm { +namespace app { + + class AsrClassifier : public Classifier { + public: + /** + * @brief Gets the top N classification results from the + * output vector. + * @param[in] outputTensor Inference output tensor from an NN model. + * @param[out] vecResults A vector of classification results + * populated by this function. + * @param[in] labels Labels vector to match classified classes + * @param[in] topNCount Number of top classifications to pick. + * @return true if successful, false otherwise. 
+ **/ + bool GetClassificationResults( + TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector & labels, uint32_t topNCount) override; + + private: + + /** + * @brief Utility function that gets the top 1 classification results from the + * output tensor (vector of vector). + * @param[in] tensor Inference output tensor from an NN model. + * @param[out] vecResults A vector of classification results + * populated by this function. + * @param[in] labels Labels vector to match classified classes. + * @param[in] scale Quantization scale. + * @param[in] zeroPoint Quantization zero point. + * @return true if successful, false otherwise. + **/ + template + bool _GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint); + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_CLASSIFIER_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/include/AsrResult.hpp b/source/use_case/kws_asr/include/AsrResult.hpp new file mode 100644 index 0000000..25fa9e8 --- /dev/null +++ b/source/use_case/kws_asr/include/AsrResult.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ASR_RESULT_HPP +#define ASR_RESULT_HPP + +#include "ClassificationResult.hpp" + +#include + +namespace arm { +namespace app { +namespace asr { + + using ResultVec = std::vector; + + /* Structure for holding asr result. */ + class AsrResult { + + public: + ResultVec m_resultVec; /* Container for "thresholded" classification results. */ + float m_timeStamp; /* Audio timestamp for this result. */ + uint32_t m_inferenceNumber; /* Corresponding inference number. */ + float m_threshold; /* Threshold value for `m_resultVec` */ + + AsrResult() = delete; + AsrResult(ResultVec& resultVec, + const float timestamp, + const uint32_t inferenceIdx, + const float scoreThreshold) { + + this->m_threshold = scoreThreshold; + this->m_timeStamp = timestamp; + this->m_inferenceNumber = inferenceIdx; + + this->m_resultVec = ResultVec(); + for (auto& i : resultVec) { + if (i.m_normalisedVal >= this->m_threshold) { + this->m_resultVec.emplace_back(i); + } + } + } + ~AsrResult() = default; + }; + +} /* namespace asr */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* ASR_RESULT_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/include/DsCnnMfcc.hpp b/source/use_case/kws_asr/include/DsCnnMfcc.hpp new file mode 100644 index 0000000..c97dd9d --- /dev/null +++ b/source/use_case/kws_asr/include/DsCnnMfcc.hpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_DSCNN_MFCC_HPP +#define KWS_ASR_DSCNN_MFCC_HPP + +#include "Mfcc.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide DS-CNN specific MFCC calculation requirements. */ + class DsCnnMFCC : public MFCC { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 40; + static constexpr uint32_t ms_defaultMelLoFreq = 20; + static constexpr uint32_t ms_defaultMelHiFreq = 4000; + static constexpr bool ms_defaultUseHtkMethod = true; + + + explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen) + : MFCC(MfccParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + numFeats, frameLen, ms_defaultUseHtkMethod)) + {} + DsCnnMFCC() = delete; + ~DsCnnMFCC() = default; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_DSCNN_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/DsCnnModel.hpp b/source/use_case/kws_asr/include/DsCnnModel.hpp new file mode 100644 index 0000000..150a48c --- /dev/null +++ b/source/use_case/kws_asr/include/DsCnnModel.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_DSCNNMODEL_HPP +#define KWS_ASR_DSCNNMODEL_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { +namespace kws { + extern const int g_FrameLength; + extern const int g_FrameStride; + extern const float g_ScoreThreshold; + extern const uint32_t g_NumMfcc; + extern const uint32_t g_NumAudioWins; +} /* namespace kws */ +} /* namespace app */ +} /* namespace arm */ + +namespace arm { +namespace app { + + class DsCnnModel : public Model { + public: + /* Indices for the expected model - based on input and output tensor shapes */ + static constexpr uint32_t ms_inputRowsIdx = 2; + static constexpr uint32_t ms_inputColsIdx = 3; + static constexpr uint32_t ms_outputRowsIdx = 2; + static constexpr uint32_t ms_outputColsIdx = 3; + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. */ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int _ms_maxOpCnt = 10; + + /* A mutable op resolver instance. 
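+         * Sized by _ms_maxOpCnt above to hold the nine built-in operators
+         * registered in EnlistOperations() plus the optional Arm Ethos-U
+         * custom operator.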
*/ + tflite::MicroMutableOpResolver<_ms_maxOpCnt> _m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_DSCNNMODEL_HPP */ diff --git a/source/use_case/kws_asr/include/KwsResult.hpp b/source/use_case/kws_asr/include/KwsResult.hpp new file mode 100644 index 0000000..45bb790 --- /dev/null +++ b/source/use_case/kws_asr/include/KwsResult.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_RESULT_HPP +#define KWS_RESULT_HPP + +#include "ClassificationResult.hpp" + +#include + +namespace arm { +namespace app { +namespace kws { + + using ResultVec = std::vector < arm::app::ClassificationResult >; + + /* Structure for holding kws result. */ + class KwsResult { + + public: + ResultVec m_resultVec; /* Container for "thresholded" classification results. */ + float m_timeStamp; /* Audio timestamp for this result. */ + uint32_t m_inferenceNumber; /* Corresponding inference number. */ + float m_threshold; /* Threshold value for `m_resultVec.` */ + + KwsResult() = delete; + KwsResult(ResultVec& resultVec, + const float timestamp, + const uint32_t inferenceIdx, + const float scoreThreshold) { + + this->m_threshold = scoreThreshold; + this->m_timeStamp = timestamp; + this->m_inferenceNumber = inferenceIdx; + + this->m_resultVec = ResultVec(); + for (auto & i : resultVec) { + if (i.m_normalisedVal >= this->m_threshold) { + this->m_resultVec.emplace_back(i); + } + } + } + ~KwsResult() = default; + }; + +} /* namespace kws */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_RESULT_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/include/OutputDecode.hpp b/source/use_case/kws_asr/include/OutputDecode.hpp new file mode 100644 index 0000000..2bbb29c --- /dev/null +++ b/source/use_case/kws_asr/include/OutputDecode.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_OUTPUT_DECODE_HPP +#define KWS_ASR_OUTPUT_DECODE_HPP + +#include "AsrClassifier.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + /** + * @brief Gets the top N classification results from the + * output vector. + * @param[in] tensor Label output from classifier. + * @return true if successful, false otherwise. 
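+     *
+     *        (In this use case the function in fact returns a std::string: it
+     *        concatenates the per-time-step labels, collapsing consecutive
+     *        duplicates and dropping the "$" token used for unknown/double
+     *        characters.)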
+ **/ + std::string DecodeOutput(const std::vector& vecResults); + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_OUTPUT_DECODE_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/include/UseCaseHandler.hpp b/source/use_case/kws_asr/include/UseCaseHandler.hpp new file mode 100644 index 0000000..1c60662 --- /dev/null +++ b/source/use_case/kws_asr/include/UseCaseHandler.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_EVT_HANDLER_HPP +#define KWS_ASR_EVT_HANDLER_HPP + +#include "AppContext.hpp" + +namespace arm { +namespace app { + + /** + * @brief Handles the inference event. + * @param[in] ctx Pointer to the application context. + * @param[in] clipIndex Index to the audio clip to classify. + * @param[in] runAll Flag to request classification of all the available audio clips. + * @return true or false based on execution success. + **/ + bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll); + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_EVT_HANDLER_HPP */ diff --git a/source/use_case/kws_asr/include/Wav2LetterMfcc.hpp b/source/use_case/kws_asr/include/Wav2LetterMfcc.hpp new file mode 100644 index 0000000..0852cbf --- /dev/null +++ b/source/use_case/kws_asr/include/Wav2LetterMfcc.hpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_WAV2LET_MFCC_HPP +#define KWS_ASR_WAV2LET_MFCC_HPP + +#include "Mfcc.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide Wav2Letter specific MFCC calculation requirements. 
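+     *
+     * Unlike the DS-CNN MFCC variant above, this one spans the full 0 - 8000 Hz
+     * range with 128 filter-bank bins and does not use the HTK mel-scale method.
+     * An illustrative construction (parameter values assumed, not mandated here):
+     *     Wav2LetterMFCC mfcc(13, 512);   // 13 features per 512-sample frame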
*/ + class Wav2LetterMFCC : public MFCC { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 128; + static constexpr uint32_t ms_defaultMelLoFreq = 0; + static constexpr uint32_t ms_defaultMelHiFreq = 8000; + static constexpr bool ms_defaultUseHtkMethod = false; + + explicit Wav2LetterMFCC(const size_t numFeats, const size_t frameLen) + : MFCC(MfccParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + numFeats, frameLen, ms_defaultUseHtkMethod)) + {} + + Wav2LetterMFCC() = delete; + ~Wav2LetterMFCC() = default; + + protected: + + /** + * @brief Overrides base class implementation of this function. + * @param[in] fftVec Vector populated with FFT magnitudes. + * @param[in] melFilterBank 2D Vector with filter bank weights. + * @param[in] filterBankFilterFirst Vector containing the first indices of filter bank + * to be used for each bin. + * @param[in] filterBankFilterLast Vector containing the last indices of filter bank + * to be used for each bin. + * @param[out] melEnergies Pre-allocated vector of MEL energies to be + * populated. + * @return true if successful, false otherwise. + */ + bool ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) override; + + /** + * @brief Override for the base class implementation convert mel + * energies to logarithmic scale. The difference from + * default behaviour is that the power is converted to dB + * and subsequently clamped. + * @param[in,out] melEnergies 1D vector of Mel energies. + **/ + void ConvertToLogarithmicScale( + std::vector& melEnergies) override; + + /** + * @brief Create a matrix used to calculate Discrete Cosine + * Transform. Override for the base class' default + * implementation as the first and last elements + * use a different normaliser. + * @param[in] inputLength Input length of the buffer on which + * DCT will be performed. + * @param[in] coefficientCount Total coefficients per input length. + * @return 1D vector with inputLength x coefficientCount elements + * populated with DCT coefficients. + */ + std::vector CreateDCTMatrix( + int32_t inputLength, + int32_t coefficientCount) override; + + /** + * @brief Given the low and high Mel values, get the normaliser + * for weights to be applied when populating the filter + * bank. Override for the base class implementation. + * @param[in] leftMel Low Mel frequency value. + * @param[in] rightMel High Mel frequency value. + * @param[in] useHTKMethod Bool to signal if HTK method is to be + * used for calculation. + * @return Value to use for normalising. + */ + float GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + bool useHTKMethod) override; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_WAV2LET_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/Wav2LetterModel.hpp b/source/use_case/kws_asr/include/Wav2LetterModel.hpp new file mode 100644 index 0000000..fb701ea --- /dev/null +++ b/source/use_case/kws_asr/include/Wav2LetterModel.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_WAV2LETTER_MODEL_HPP +#define KWS_ASR_WAV2LETTER_MODEL_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { +namespace asr { + extern const int g_FrameLength; + extern const int g_FrameStride; + extern const float g_ScoreThreshold; + extern const int g_ctxLen; +} /* namespace asr */ +} /* namespace app */ +} /* namespace arm */ + +namespace arm { +namespace app { + + class Wav2LetterModel : public Model { + + public: + /* Indices for the expected model - based on input and output tensor shapes */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_outputRowsIdx = 2; + static constexpr uint32_t ms_outputColsIdx = 3; + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. */ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int _ms_maxOpCnt = 5; + + /* A mutable op resolver instance. */ + tflite::MicroMutableOpResolver<_ms_maxOpCnt> _m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_WAV2LETTER_MODEL_HPP */ diff --git a/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp new file mode 100644 index 0000000..3a9d401 --- /dev/null +++ b/source/use_case/kws_asr/include/Wav2LetterPostprocess.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_WAV2LET_POSTPROC_HPP +#define KWS_ASR_WAV2LET_POSTPROC_HPP + +#include "TensorFlowLiteMicro.hpp" /* TensorFlow headers */ +#include "hal.h" /* stdout facility */ + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + /** + * @brief Helper class to manage tensor post-processing for "wav2letter" + * output. + */ + class Postprocess { + public: + /** + * @brief Constructor + * @param[in] contextLen Left and right context length for + * output tensor. + * @param[in] innerLen This is the length of the section + * between left and right context. 
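+             * @param[in]   blankTokenIdx   Index of the labels' blank token in
+             *                              the model output.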
+ **/ + Postprocess(uint32_t contextLen, + uint32_t innerLen, + uint32_t blankTokenIdx); + + Postprocess() = delete; + ~Postprocess() = default; + + /** + * @brief Erases the required part of the tensor based + * on context lengths set up during initialisation + * @param[in] tensor Pointer to the tensor + * @param[in] axisIdx Index of the axis on which erase is + * performed. + * @param[in] lastIteration Flag to signal is this is the + * last iteration in which case + * the right context is preserved. + * @return true if successful, false otherwise. + */ + bool Invoke(TfLiteTensor* tensor, + uint32_t axisIdx, + bool lastIteration = false); + + private: + uint32_t _m_contextLen; /* Lengths of left and right contexts. */ + uint32_t _m_innerLen; /* Length of inner context. */ + uint32_t _m_totalLen; /* Total length of the required axis. */ + uint32_t _m_countIterations; /* Current number of iterations. */ + uint32_t _m_blankTokenIdx; /* Index of the labels blank token. */ + /** + * @brief Checks if the tensor and axis index are valid + * inputs to the object - based on how it has been + * initialised. + * @return true if valid, false otherwise. + */ + bool _IsInputValid(TfLiteTensor* tensor, + uint32_t axisIdx) const; + + /** + * @brief Gets the tensor data element size in bytes based + * on the tensor type. + * @return Size in bytes, 0 if not supported. + */ + uint32_t _GetTensorElementSize(TfLiteTensor* tensor); + + /** + * @brief Erases sections from the data assuming row-wise + * arrangement along the context axis. + * @return true if successful, false otherwise. + */ + bool _EraseSectionsRowWise(uint8_t* ptrData, + uint32_t strideSzBytes, + bool lastIteration); + + }; + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_WAV2LET_POSTPROC_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp new file mode 100644 index 0000000..3ffabb4 --- /dev/null +++ b/source/use_case/kws_asr/include/Wav2LetterPreprocess.hpp @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_WAV2LET_PREPROC_HPP +#define KWS_ASR_WAV2LET_PREPROC_HPP + +#include "Wav2LetterModel.hpp" +#include "Wav2LetterMfcc.hpp" +#include "AudioUtils.hpp" +#include "DataStructures.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + /* Class to facilitate pre-processing calculation for Wav2Letter model + * for ASR. */ + using AudioWindow = SlidingWindow ; + + class Preprocess { + public: + /** + * @brief Constructor + * @param[in] numMfccFeatures Number of MFCC features per window. + * @param[in] windowLen Number of elements in a window. + * @param[in] windowStride Stride (in number of elements) for + * moving the window. + * @param[in] numMfccVectors Number of MFCC vectors per window. 
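+             *
+             * Typical values for the Wav2Letter ASR model (assumed here purely
+             * for illustration): 13 MFCC features, a 512-sample window, a
+             * 160-sample stride and 296 feature vectors per inference.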
+ */ + Preprocess( + uint32_t numMfccFeatures, + uint32_t windowLen, + uint32_t windowStride, + uint32_t numMfccVectors); + Preprocess() = delete; + ~Preprocess() = default; + + /** + * @brief Calculates the features required from audio data. This + * includes MFCC, first and second order deltas, + * normalisation and finally, quantisation. The tensor is + * populated with feature from a given window placed along + * in a single row. + * @param[in] audioData Pointer to the first element of audio data. + * @param[in] audioDataLen Number of elements in the audio data. + * @param[in] tensor Tensor to be populated. + * @return true if successful, false in case of error. + */ + bool Invoke(const int16_t * audioData, + uint32_t audioDataLen, + TfLiteTensor * tensor); + + protected: + /** + * @brief Computes the first and second order deltas for the + * MFCC buffers - they are assumed to be populated. + * + * @param[in] mfcc MFCC buffers. + * @param[out] delta1 Result of the first diff computation. + * @param[out] delta2 Result of the second diff computation. + * + * @return true if successful, false otherwise. + */ + static bool _ComputeDeltas(Array2d& mfcc, + Array2d& delta1, + Array2d& delta2); + + /** + * @brief Given a 2D vector of floats, computes the mean. + * @param[in] vec Vector of vector of floats. + * @return Mean value. + */ + static float _GetMean(Array2d& vec); + + /** + * @brief Given a 2D vector of floats, computes the stddev. + * @param[in] vec Vector of vector of floats. + * @param[in] mean Mean value of the vector passed in. + * @return stddev value. + */ + static float _GetStdDev(Array2d& vec, + float mean); + + /** + * @brief Given a 2D vector of floats, normalises it using + * the mean and the stddev + * @param[in,out] vec Vector of vector of floats. + */ + static void _NormaliseVec(Array2d& vec); + + /** + * @brief Normalises the MFCC and delta buffers. + */ + void _Normalise(); + + /** + * @brief Given the quantisation and data type limits, computes + * the quantised values of a floating point input data. + * @param[in] elem Element to be quantised. + * @param[in] quantScale Scale. + * @param[in] quantOffset Offset. + * @param[in] minVal Numerical limit - minimum. + * @param[in] maxVal Numerical limit - maximum. + * @return Floating point quantised value. + */ + static float _GetQuantElem( + float elem, + float quantScale, + int quantOffset, + float minVal, + float maxVal); + + /** + * @brief Quantises the MFCC and delta buffers, and places them + * in the output buffer. While doing so, it transposes + * the data. Reason: Buffers in this class are arranged + * for "time" axis to be row major. Primary reason for + * this being the convolution speed up (as we can use + * contiguous memory). The output, however, requires the + * time axis to be in column major arrangement. + * @param[in] outputBuf Pointer to the output buffer. + * @param[in] outputBufSz Output buffer's size. + * @param[in] quantScale Quantisation scale. + * @param[in] quantOffset Quantisation offset. + */ + template + bool _Quantise( + T * outputBuf, + const uint32_t outputBufSz, + const float quantScale, + const int quantOffset) + { + /* Check the output size will for everything. */ + if (outputBufSz < (this->_m_mfccBuf.size(0) * 3 * sizeof(T))) { + printf_err("Tensor size too small for features\n"); + return false; + } + + /* Populate. 
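+             * The output is written one time step at a time as three consecutive
+             * blocks - [mfcc(t), delta1(t), delta2(t)] - which is why each of the
+             * three write pointers below is advanced by ptrIncr (two vector
+             * lengths) after every column is copied.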
*/ + T * outputBufMfcc = outputBuf; + T * outputBufD1 = outputBuf + this->_m_numMfccFeats; + T * outputBufD2 = outputBufD1 + this->_m_numMfccFeats; + const uint32_t ptrIncr = this->_m_numMfccFeats * 2; /* (3 vectors - 1 vector) */ + + const float minVal = std::numeric_limits::min(); + const float maxVal = std::numeric_limits::max(); + + /* We need to do a transpose while copying and concatenating + * the tensor. */ + for (uint32_t j = 0; j < this->_m_numFeatVectors; ++j) { + for (uint32_t i = 0; i < this->_m_numMfccFeats; ++i) { + *outputBufMfcc++ = static_cast(this->_GetQuantElem( + this->_m_mfccBuf(i, j), quantScale, + quantOffset, minVal, maxVal)); + *outputBufD1++ = static_cast(this->_GetQuantElem( + this->_m_delta1Buf(i, j), quantScale, + quantOffset, minVal, maxVal)); + *outputBufD2++ = static_cast(this->_GetQuantElem( + this->_m_delta2Buf(i, j), quantScale, + quantOffset, minVal, maxVal)); + } + outputBufMfcc += ptrIncr; + outputBufD1 += ptrIncr; + outputBufD2 += ptrIncr; + } + + return true; + } + + private: + Wav2LetterMFCC _m_mfcc; /* MFCC instance. */ + + /* Actual buffers to be populated. */ + Array2d _m_mfccBuf; /* Contiguous buffer 1D: MFCC */ + Array2d _m_delta1Buf; /* Contiguous buffer 1D: Delta 1 */ + Array2d _m_delta2Buf; /* Contiguous buffer 1D: Delta 2 */ + + uint32_t _m_windowLen; /* Window length for MFCC. */ + uint32_t _m_windowStride; /* Window stride len for MFCC. */ + uint32_t _m_numMfccFeats; /* Number of MFCC features per window. */ + uint32_t _m_numFeatVectors; /* Number of _m_numMfccFeats. */ + AudioWindow _m_window; /* Sliding window. */ + + }; + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_WAV2LET_PREPROC_HPP */ \ No newline at end of file diff --git a/source/use_case/kws_asr/src/AsrClassifier.cc b/source/use_case/kws_asr/src/AsrClassifier.cc new file mode 100644 index 0000000..bc86e09 --- /dev/null +++ b/source/use_case/kws_asr/src/AsrClassifier.cc @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AsrClassifier.hpp" + +#include "hal.h" +#include "TensorFlowLiteMicro.hpp" +#include "Wav2LetterModel.hpp" + +template +bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint) +{ + const uint32_t nElems = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputRowsIdx]; + const uint32_t nLetters = tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]; + + + /* NOTE: tensor's size verification against labels should be + * checked by the calling/public function. */ + if (nLetters < 1) { + return false; + } + + /* Final results' container. */ + vecResults = std::vector(nElems); + + T* tensorData = tflite::GetTensorData(tensor); + + /* Get the top 1 results. 
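+     *
+     * Each row of the output tensor corresponds to one time step: the loop below
+     * takes the arg-max across the nLetters columns of that row and de-quantises
+     * it as scale * (value - zeroPoint) before recording the label and score.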
*/ + for (uint32_t i = 0, row = 0; i < nElems; ++i, row+=nLetters) { + std::pair top_1 = std::make_pair(tensorData[row + 0], 0); + + for (uint32_t j = 1; j < nLetters; ++j) { + if (top_1.first < tensorData[row + j]) { + top_1.first = tensorData[row + j]; + top_1.second = j; + } + } + + double score = static_cast (top_1.first); + vecResults[i].m_normalisedVal = scale * (score - zeroPoint); + vecResults[i].m_label = labels[top_1.second]; + vecResults[i].m_labelIdx = top_1.second; + } + + return true; +} +template bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint); +template bool arm::app::AsrClassifier::_GetTopResults(TfLiteTensor* tensor, + std::vector& vecResults, + const std::vector & labels, double scale, double zeroPoint); + +bool arm::app::AsrClassifier::GetClassificationResults( + TfLiteTensor* outputTensor, + std::vector& vecResults, + const std::vector & labels, uint32_t topNCount) +{ + vecResults.clear(); + + constexpr int minTensorDims = static_cast( + (arm::app::Wav2LetterModel::ms_outputRowsIdx > arm::app::Wav2LetterModel::ms_outputColsIdx)? + arm::app::Wav2LetterModel::ms_outputRowsIdx : arm::app::Wav2LetterModel::ms_outputColsIdx); + + constexpr uint32_t outColsIdx = arm::app::Wav2LetterModel::ms_outputColsIdx; + + /* Sanity checks. */ + if (outputTensor == nullptr) { + printf_err("Output vector is null pointer.\n"); + return false; + } else if (outputTensor->dims->size < minTensorDims) { + printf_err("Output tensor expected to be 3D (1, m, n)\n"); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) < topNCount) { + printf_err("Output vectors are smaller than %u\n", topNCount); + return false; + } else if (static_cast(outputTensor->dims->data[outColsIdx]) != labels.size()) { + printf("Output size doesn't match the labels' size\n"); + return false; + } + + if (topNCount != 1) { + warn("TopNCount value ignored in this implementation\n"); + } + + /* To return the floating point values, we need quantization parameters. */ + QuantParams quantParams = GetTensorQuantParams(outputTensor); + + bool resultState; + + switch (outputTensor->type) { + case kTfLiteUInt8: + resultState = this->_GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + case kTfLiteInt8: + resultState = this->_GetTopResults( + outputTensor, vecResults, + labels, quantParams.scale, + quantParams.offset); + break; + default: + printf_err("Tensor type %s not supported by classifier\n", + TfLiteTypeGetName(outputTensor->type)); + return false; + } + + if (!resultState) { + printf_err("Failed to get sorted set\n"); + return false; + } + + return true; +} \ No newline at end of file diff --git a/source/use_case/kws_asr/src/DsCnnModel.cc b/source/use_case/kws_asr/src/DsCnnModel.cc new file mode 100644 index 0000000..b573a12 --- /dev/null +++ b/source/use_case/kws_asr/src/DsCnnModel.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "DsCnnModel.hpp" + +#include "hal.h" + +namespace arm { +namespace app { +namespace kws { + extern uint8_t* GetModelPointer(); + extern size_t GetModelLen(); +} /* namespace kws */ +} /* namespace app */ +} /* namespace arm */ + +const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +bool arm::app::DsCnnModel::EnlistOperations() +{ + this->_m_opResolver.AddAveragePool2D(); + this->_m_opResolver.AddConv2D(); + this->_m_opResolver.AddDepthwiseConv2D(); + this->_m_opResolver.AddFullyConnected(); + this->_m_opResolver.AddRelu(); + this->_m_opResolver.AddSoftmax(); + this->_m_opResolver.AddQuantize(); + this->_m_opResolver.AddDequantize(); + this->_m_opResolver.AddReshape(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->_m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +const uint8_t* arm::app::DsCnnModel::ModelPointer() +{ + return arm::app::kws::GetModelPointer(); +} + +size_t arm::app::DsCnnModel::ModelSize() +{ + return arm::app::kws::GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc new file mode 100644 index 0000000..37146c9 --- /dev/null +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -0,0 +1,233 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "hal.h" /* Brings in platform definitions. */ +#include "InputFiles.hpp" /* For input images. */ +#include "Labels_dscnn.hpp" /* For DS-CNN label strings. */ +#include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ +#include "Classifier.hpp" /* KWS classifier. */ +#include "AsrClassifier.hpp" /* ASR classifier. */ +#include "DsCnnModel.hpp" /* KWS model class for running inference. */ +#include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ +#include "UseCaseCommonUtils.hpp" /* Utils functions. */ +#include "UseCaseHandler.hpp" /* Handlers for different user options. */ +#include "Wav2LetterPreprocess.hpp" /* ASR pre-processing class. */ +#include "Wav2LetterPostprocess.hpp"/* ASR post-processing class. */ + +using KwsClassifier = arm::app::Classifier; + +enum opcodes +{ + MENU_OPT_RUN_INF_NEXT = 1, /* Run on next vector. */ + MENU_OPT_RUN_INF_CHOSEN, /* Run on a user provided vector index. 
*/ + MENU_OPT_RUN_INF_ALL, /* Run inference on all. */ + MENU_OPT_SHOW_MODEL_INFO, /* Show model info. */ + MENU_OPT_LIST_AUDIO_CLIPS /* List the current baked audio clips. */ +}; + +static void DisplayMenu() +{ + printf("\n\nUser input required\n"); + printf("Enter option number from:\n\n"); + printf(" %u. Classify next audio clip\n", MENU_OPT_RUN_INF_NEXT); + printf(" %u. Classify audio clip at chosen index\n", MENU_OPT_RUN_INF_CHOSEN); + printf(" %u. Run classification on all audio clips\n", MENU_OPT_RUN_INF_ALL); + printf(" %u. Show NN model info\n", MENU_OPT_SHOW_MODEL_INFO); + printf(" %u. List audio clips\n\n", MENU_OPT_LIST_AUDIO_CLIPS); + printf(" Choice: "); +} + +/** @brief Gets the number of MFCC features for a single window. */ +static uint32_t GetNumMfccFeatures(const arm::app::Model& model); + +/** @brief Gets the number of MFCC feature vectors to be computed. */ +static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model); + +/** @brief Gets the output context length (left and right) for post-processing. */ +static uint32_t GetOutputContextLen(const arm::app::Model& model, + uint32_t inputCtxLen); + +/** @brief Gets the output inner length for post-processing. */ +static uint32_t GetOutputInnerLen(const arm::app::Model& model, + uint32_t outputCtxLen); + +void main_loop(hal_platform& platform) +{ + /* Model wrapper objects. */ + arm::app::DsCnnModel kwsModel; + arm::app::Wav2LetterModel asrModel; + + /* Load the models. */ + if (!kwsModel.Init()) { + printf_err("Failed to initialise KWS model\n"); + return; + } + + /* Initialise the asr model using the same allocator from KWS + * to re-use the tensor arena. */ + if (!asrModel.Init(kwsModel.GetAllocator())) { + printf_err("Failed to initalise ASR model\n"); + return; + } + + /* Initialise ASR pre-processing. */ + arm::app::audio::asr::Preprocess prep( + GetNumMfccFeatures(asrModel), + arm::app::asr::g_FrameLength, + arm::app::asr::g_FrameStride, + GetNumMfccFeatureVectors(asrModel)); + + /* Initialise ASR post-processing. */ + const uint32_t outputCtxLen = GetOutputContextLen(asrModel, arm::app::asr::g_ctxLen); + const uint32_t blankTokenIdx = 28; + arm::app::audio::asr::Postprocess postp( + outputCtxLen, + GetOutputInnerLen(asrModel, outputCtxLen), + blankTokenIdx); + + /* Instantiate application context. */ + arm::app::ApplicationContext caseContext; + + caseContext.Set("platform", platform); + caseContext.Set("kwsmodel", kwsModel); + caseContext.Set("asrmodel", asrModel); + caseContext.Set("clipIndex", 0); + caseContext.Set("ctxLen", arm::app::asr::g_ctxLen); /* Left and right context length (MFCC feat vectors). */ + caseContext.Set("kwsframeLength", arm::app::kws::g_FrameLength); + caseContext.Set("kwsframeStride", arm::app::kws::g_FrameStride); + caseContext.Set("kwsscoreThreshold", arm::app::kws::g_ScoreThreshold); /* Normalised score threshold. */ + caseContext.Set("kwsNumMfcc", arm::app::kws::g_NumMfcc); + caseContext.Set("kwsNumAudioWins", arm::app::kws::g_NumAudioWins); + + caseContext.Set("asrframeLength", arm::app::asr::g_FrameLength); + caseContext.Set("asrframeStride", arm::app::asr::g_FrameStride); + caseContext.Set("asrscoreThreshold", arm::app::asr::g_ScoreThreshold); /* Normalised score threshold. */ + + KwsClassifier kwsClassifier; /* Classifier wrapper object. */ + arm::app::AsrClassifier asrClassifier; /* Classifier wrapper object. 
*/ + caseContext.Set("kwsclassifier", kwsClassifier); + caseContext.Set("asrclassifier", asrClassifier); + + caseContext.Set("preprocess", prep); + caseContext.Set("postprocess", postp); + + std::vector asrLabels; + arm::app::asr::GetLabelsVector(asrLabels); + std::vector kwsLabels; + arm::app::kws::GetLabelsVector(kwsLabels); + caseContext.Set&>("asrlabels", asrLabels); + caseContext.Set&>("kwslabels", kwsLabels); + + /* Index of the kws outputs we trigger ASR on. */ + caseContext.Set("keywordindex", 2); + + /* Loop. */ + bool executionSuccessful = true; + constexpr bool bUseMenu = NUMBER_OF_FILES > 1 ? true : false; + + /* Loop. */ + do { + int menuOption = MENU_OPT_RUN_INF_NEXT; + if (bUseMenu) { + DisplayMenu(); + menuOption = arm::app::ReadUserInputAsInt(platform); + printf("\n"); + } + switch (menuOption) { + case MENU_OPT_RUN_INF_NEXT: + executionSuccessful = ClassifyAudioHandler( + caseContext, + caseContext.Get("clipIndex"), + false); + break; + case MENU_OPT_RUN_INF_CHOSEN: { + printf(" Enter the audio clip index [0, %d]: ", + NUMBER_OF_FILES-1); + auto clipIndex = static_cast( + arm::app::ReadUserInputAsInt(platform)); + executionSuccessful = ClassifyAudioHandler(caseContext, + clipIndex, + false); + break; + } + case MENU_OPT_RUN_INF_ALL: + executionSuccessful = ClassifyAudioHandler( + caseContext, + caseContext.Get("clipIndex"), + true); + break; + case MENU_OPT_SHOW_MODEL_INFO: + executionSuccessful = kwsModel.ShowModelInfoHandler(); + executionSuccessful = asrModel.ShowModelInfoHandler(); + break; + case MENU_OPT_LIST_AUDIO_CLIPS: + executionSuccessful = ListFilesHandler(caseContext); + break; + default: + printf("Incorrect choice, try again."); + break; + } + } while (executionSuccessful && bUseMenu); + info("Main loop terminated.\n"); +} + +static uint32_t GetNumMfccFeatures(const arm::app::Model& model) +{ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputCols = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputColsIdx]; + if (0 != inputCols % 3) { + printf_err("Number of input columns is not a multiple of 3\n"); + } + return std::max(inputCols/3, 0); +} + +static uint32_t GetNumMfccFeatureVectors(const arm::app::Model& model) +{ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + const int inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; + return std::max(inputRows, 0); +} + +static uint32_t GetOutputContextLen(const arm::app::Model& model, const uint32_t inputCtxLen) +{ + const uint32_t inputRows = GetNumMfccFeatureVectors(model); + const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen); + constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; + + /* Check to make sure that the input tensor supports the above context and inner lengths. 
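+     *
+     * Worked illustration (numbers assumed, not taken from the actual model):
+     * with 296 input rows, an input context of 98 and 148 output rows, the
+     * row ratio is 296 / 148 = 2, giving an output context of round(98 / 2) = 49;
+     * GetOutputInnerLen() would then return 148 - (2 * 49) = 50.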
*/ + if (inputRows <= 2 * inputCtxLen || inputRows <= inputInnerLen) { + printf_err("Input rows not compatible with ctx of %u\n", + inputCtxLen); + return 0; + } + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + + const float tensorColRatio = static_cast(inputRows)/ + static_cast(outputRows); + + return std::round(static_cast(inputCtxLen)/tensorColRatio); +} + +static uint32_t GetOutputInnerLen(const arm::app::Model& model, + const uint32_t outputCtxLen) +{ + constexpr uint32_t ms_outputRowsIdx = arm::app::Wav2LetterModel::ms_outputRowsIdx; + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + const uint32_t outputRows = std::max(outputTensor->dims->data[ms_outputRowsIdx], 0); + return (outputRows - (2 * outputCtxLen)); +} diff --git a/source/use_case/kws_asr/src/OutputDecode.cc b/source/use_case/kws_asr/src/OutputDecode.cc new file mode 100644 index 0000000..41fbe07 --- /dev/null +++ b/source/use_case/kws_asr/src/OutputDecode.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "OutputDecode.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + std::string DecodeOutput(const std::vector& vecResults) + { + std::string CleanOutputBuffer; + + for (size_t i = 0; i < vecResults.size(); ++i) /* For all elements in vector. */ + { + while (i+1 < vecResults.size() && + vecResults[i].m_label == vecResults[i+1].m_label) /* While the current element is equal to the next, ignore it and move on. */ + { + ++i; + } + if (vecResults[i].m_label != "$") /* $ is a character used to represent unknown and double characters so should not be in output. */ + { + CleanOutputBuffer += vecResults[i].m_label; /* If the element is different to the next, it will be appended to CleanOutputBuffer. */ + } + } + + return CleanOutputBuffer; /* Return string type containing clean output. */ + } + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc new file mode 100644 index 0000000..c50796f --- /dev/null +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -0,0 +1,707 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "UseCaseHandler.hpp" + +#include "hal.h" +#include "InputFiles.hpp" +#include "AudioUtils.hpp" +#include "UseCaseCommonUtils.hpp" +#include "DsCnnModel.hpp" +#include "DsCnnMfcc.hpp" +#include "Classifier.hpp" +#include "KwsResult.hpp" +#include "Wav2LetterMfcc.hpp" +#include "Wav2LetterPreprocess.hpp" +#include "Wav2LetterPostprocess.hpp" +#include "AsrResult.hpp" +#include "AsrClassifier.hpp" +#include "OutputDecode.hpp" + + +using KwsClassifier = arm::app::Classifier; + +namespace arm { +namespace app { + + enum AsrOutputReductionAxis { + AxisRow = 1, + AxisCol = 2 + }; + + struct KWSOutput { + bool executionSuccess = false; + const int16_t* asrAudioStart = nullptr; + int32_t asrAudioSamples = 0; + }; + + /** + * @brief Helper function to increment current audio clip index + * @param[in,out] ctx pointer to the application context object + **/ + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx); + + /** + * @brief Helper function to increment current audio clip index + * @param[in,out] ctx pointer to the application context object + **/ + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx); + + /** + * @brief Helper function to set the audio clip index + * @param[in,out] ctx pointer to the application context object + * @param[in] idx value to be set + * @return true if index is set, false otherwise + **/ + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx); + + /** + * @brief Presents kws inference results using the data presentation + * object. + * @param[in] platform reference to the hal platform object + * @param[in] results vector of classification results to be displayed + * @param[in] infTimeMs inference time in milliseconds, if available + * Otherwise, this can be passed in as 0. + * @return true if successful, false otherwise + **/ + static bool _PresentInferenceResult(hal_platform& platform, std::vector& results); + + /** + * @brief Presents asr inference results using the data presentation + * object. + * @param[in] platform reference to the hal platform object + * @param[in] results vector of classification results to be displayed + * @param[in] infTimeMs inference time in milliseconds, if available + * Otherwise, this can be passed in as 0. + * @return true if successful, false otherwise + **/ + static bool _PresentInferenceResult(hal_platform& platform, std::vector& results); + + /** + * @brief Returns a function to perform feature calculation and populates input tensor data with + * MFCC data. + * + * Input tensor data type check is performed to choose correct MFCC feature data type. + * If tensor has an integer data type then original features are quantised. + * + * Warning: mfcc calculator provided as input must have the same life scope as returned function. + * + * @param[in] mfcc MFCC feature calculator. + * @param[in,out] inputTensor Input tensor pointer to store calculated features. + * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors). + * + * @return function function to be called providing audio sample and sliding window index. + **/ + static std::function&, int, bool, size_t)> + GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + TfLiteTensor* inputTensor, + size_t cacheSize); + + /** + * @brief Performs the KWS pipeline. 
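+ * With the default parameters set in usecase.cmake (49 audio windows, frame
+ * length 640, frame stride 320, 16 kHz audio), one KWS inference consumes
+ * 49*320 + (640 - 320) = 16000 samples (1 s) and the clip is advanced by half
+ * of that (0.5 s) between inferences.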
+ * @param[in,out] ctx pointer to the application context object + * + * @return KWSOutput struct containing pointer to audio data where ASR should begin + * and how much data to process. + */ + static KWSOutput doKws(ApplicationContext& ctx) { + constexpr uint32_t dataPsnTxtInfStartX = 20; + constexpr uint32_t dataPsnTxtInfStartY = 40; + + constexpr int minTensorDims = static_cast( + (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? + arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + + KWSOutput output; + + auto& kwsModel = ctx.Get("kwsmodel"); + if (!kwsModel.IsInited()) { + printf_err("KWS model has not been initialised\n"); + return output; + } + + const int kwsFrameLength = ctx.Get("kwsframeLength"); + const int kwsFrameStride = ctx.Get("kwsframeStride"); + const float kwsScoreThreshold = ctx.Get("kwsscoreThreshold"); + + TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0); + TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0); + + if (!kwsInputTensor->dims) { + printf_err("Invalid input tensor dims\n"); + return output; + } else if (kwsInputTensor->dims->size < minTensorDims) { + printf_err("Input tensor dimension should be >= %d\n", minTensorDims); + return output; + } + + const uint32_t kwsNumMfccFeats = ctx.Get("kwsNumMfcc"); + const uint32_t kwsNumAudioWindows = ctx.Get("kwsNumAudioWins"); + + audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength); + kwsMfcc.Init(); + + /* Deduce the data length required for 1 KWS inference from the network parameters. */ + auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride + + (kwsFrameLength - kwsFrameStride); + auto kwsMfccWindowSize = kwsFrameLength; + auto kwsMfccWindowStride = kwsFrameStride; + + /* We are choosing to move by half the window size => for a 1 second window size, + * this means an overlap of 0.5 seconds. */ + auto kwsAudioDataStride = kwsAudioDataWindowSize / 2; + + info("KWS audio data window size %u\n", kwsAudioDataWindowSize); + + /* Stride must be multiple of mfcc features window stride to re-use features. */ + if (0 != kwsAudioDataStride % kwsMfccWindowStride) { + kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride; + } + + auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride; + + /* We expect to be sampling 1 second worth of data at a time + * NOTE: This is only used for time stamp calculation. */ + const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + + auto currentIndex = ctx.Get("clipIndex"); + + /* Creating a mfcc features sliding window for the data required for 1 inference. */ + auto kwsAudioMFCCWindowSlider = audio::SlidingWindow( + get_audio_array(currentIndex), + kwsAudioDataWindowSize, kwsMfccWindowSize, + kwsMfccWindowStride); + + /* Creating a sliding window through the whole audio clip. */ + auto audioDataSlider = audio::SlidingWindow( + get_audio_array(currentIndex), + get_audio_array_size(currentIndex), + kwsAudioDataWindowSize, kwsAudioDataStride); + + /* Calculate number of the feature vectors in the window overlap region. + * These feature vectors will be reused.*/ + size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1 + - kwsMfccVectorsInAudioStride; + + auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor, + numberOfReusedFeatureVectors); + + if (!kwsMfccFeatureCalc){ + return output; + } + + /* Container for KWS results. 
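+ * Each entry below pairs one window's classification with a time stamp of
+ * windowIndex * secondsPerSample * kwsAudioDataStride; with the default 8000
+ * sample stride at 16 kHz, window 3 is reported at 3 * (1/16000) * 8000 = 1.5 s.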
*/ + std::vector kwsResults; + + /* Display message on the LCD - inference running. */ + auto& platform = ctx.Get("platform"); + std::string str_inf{"Running KWS inference... "}; + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + info("Running KWS inference on audio clip %u => %s\n", + currentIndex, get_filename(currentIndex)); + + /* Start sliding through audio clip. */ + while (audioDataSlider.HasNext()) { + const int16_t* inferenceWindow = audioDataSlider.Next(); + + /* We moved to the next window - set the features sliding to the new address. */ + kwsAudioMFCCWindowSlider.Reset(inferenceWindow); + + /* The first window does not have cache ready. */ + bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0; + + /* Start calculating features inside one audio sliding window. */ + while (kwsAudioMFCCWindowSlider.HasNext()) { + const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next(); + std::vector kwsMfccAudioData = + std::vector(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize); + + /* Compute features for this window and write them to input tensor. */ + kwsMfccFeatureCalc(kwsMfccAudioData, + kwsAudioMFCCWindowSlider.Index(), + useCache, + kwsMfccVectorsInAudioStride); + } + + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + audioDataSlider.TotalStrides() + 1); + + /* Run inference over this audio clip sliding window. */ + arm::app::RunInference(platform, kwsModel); + + std::vector kwsClassificationResult; + auto& kwsClassifier = ctx.Get("kwsclassifier"); + + kwsClassifier.GetClassificationResults( + kwsOutputTensor, kwsClassificationResult, + ctx.Get&>("kwslabels"), 1); + + kwsResults.emplace_back( + kws::KwsResult( + kwsClassificationResult, + audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride, + audioDataSlider.Index(), kwsScoreThreshold) + ); + + /* Keyword detected. */ + if (kwsClassificationResult[0].m_labelIdx == ctx.Get("keywordindex")) { + output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize; + output.asrAudioSamples = get_audio_array_size(currentIndex) - + (audioDataSlider.NextWindowStartIndex() - + kwsAudioDataStride + kwsAudioDataWindowSize); + break; + } + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(kwsOutputTensor); +#endif /* VERIFY_TEST_OUTPUT */ + + } /* while (audioDataSlider.HasNext()) */ + + /* Erase. */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + if (!_PresentInferenceResult(platform, kwsResults)) { + return output; + } + + output.executionSuccess = true; + return output; + } + + /** + * @brief Performs the ASR pipeline. + * + * @param ctx[in/out] pointer to the application context object + * @param kwsOutput[in] struct containing pointer to audio data where ASR should begin + * and how much data to process + * @return bool true if pipeline executed without failure + */ + static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) { + constexpr uint32_t dataPsnTxtInfStartX = 20; + constexpr uint32_t dataPsnTxtInfStartY = 40; + + auto& platform = ctx.Get("platform"); + platform.data_psn->clear(COLOR_BLACK); + + /* Get model reference. */ + auto& asrModel = ctx.Get("asrmodel"); + if (!asrModel.IsInited()) { + printf_err("ASR model has not been initialised\n"); + return false; + } + + /* Get score threshold to be applied for the classifier (post-inference). 
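+ * (Defaults from usecase.cmake: 0.9 for KWS, 0.5 for ASR.) For the window
+ * arithmetic a few lines below, an illustrative example: assuming 296 input
+ * rows with ctxLen = 98, frame length 512 and frame stride 160, the ASR window
+ * is (296 - 1) * 160 + 512 = 47712 samples (about 3 s at 16 kHz) and the window
+ * stride is 100 * 160 = 16000 samples (1 s).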
*/ + auto asrScoreThreshold = ctx.Get("asrscoreThreshold"); + + /* Dimensions of the tensor should have been verified by the callee. */ + TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0); + TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0); + const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx]; + + /* Populate ASR MFCC related parameters. */ + auto asrMfccParamsWinLen = ctx.Get("asrframeLength"); + auto asrMfccParamsWinStride = ctx.Get("asrframeStride"); + + /* Populate ASR inference context and inner lengths for input. */ + auto asrInputCtxLen = ctx.Get("ctxLen"); + const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen); + + /* Make sure the input tensor supports the above context and inner lengths. */ + if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) { + printf_err("ASR input rows not compatible with ctx length %u\n", asrInputCtxLen); + return false; + } + + /* Audio data stride corresponds to inputInnerLen feature vectors. */ + const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) * + asrMfccParamsWinStride + (asrMfccParamsWinLen); + const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride; + const float asrAudioParamsSecondsPerSample = + (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq); + + /* Get pre/post-processing objects */ + auto& asrPrep = ctx.Get("preprocess"); + auto& asrPostp = ctx.Get("postprocess"); + + /* Set default reduction axis for post-processing. */ + const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx; + + /* Get the remaining audio buffer and respective size from KWS results. */ + const int16_t* audioArr = kwsOutput.asrAudioStart; + const uint32_t audioArrSize = kwsOutput.asrAudioSamples; + + /* Audio clip must have enough samples to produce 1 MFCC feature. */ + std::vector audioBuffer = std::vector(audioArr, audioArr + audioArrSize); + if (audioArrSize < asrMfccParamsWinLen) { + printf_err("Not enough audio samples, minimum needed is %u\n", asrMfccParamsWinLen); + return false; + } + + /* Initialise an audio slider. */ + auto audioDataSlider = audio::ASRSlidingWindow( + audioBuffer.data(), + audioBuffer.size(), + asrAudioParamsWinLen, + asrAudioParamsWinStride); + + /* Declare a container for results. */ + std::vector asrResults; + + /* Display message on the LCD - inference running. */ + std::string str_inf{"Running ASR inference... "}; + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0); + + size_t asrInferenceWindowLen = asrAudioParamsWinLen; + + /* Start sliding through audio clip. */ + while (audioDataSlider.HasNext()) { + + /* If not enough audio see how much can be sent for processing. */ + size_t nextStartIndex = audioDataSlider.NextWindowStartIndex(); + if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) { + asrInferenceWindowLen = audioBuffer.size() - nextStartIndex; + } + + const int16_t* asrInferenceWindow = audioDataSlider.Next(); + + info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, + static_cast(ceilf(audioDataSlider.FractionalTotalStrides() + 1))); + + Profiler prepProfiler{&platform, "pre-processing"}; + prepProfiler.StartProfiling(); + + /* Calculate MFCCs, deltas and populate the input tensor. 
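+ * Each row of the input tensor ends up holding the MFCCs for one window followed
+ * by their first and second order deltas, i.e. (illustrative layout)
+ * [ mfcc[0..N) | delta1[0..N) | delta2[0..N) ], which is why MainLoop.cc derives
+ * the MFCC count as inputCols / 3.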
*/ + asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor); + + prepProfiler.StopProfiling(); + std::string prepProfileResults = prepProfiler.GetResultsAndReset(); + info("%s\n", prepProfileResults.c_str()); + + /* Run inference over this audio clip sliding window. */ + arm::app::RunInference(platform, asrModel); + + /* Post-process. */ + asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext()); + + /* Get results. */ + std::vector asrClassificationResult; + auto& asrClassifier = ctx.Get("asrclassifier"); + asrClassifier.GetClassificationResults( + asrOutputTensor, asrClassificationResult, + ctx.Get&>("asrlabels"), 1); + + asrResults.emplace_back(asr::AsrResult(asrClassificationResult, + (audioDataSlider.Index() * + asrAudioParamsSecondsPerSample * + asrAudioParamsWinStride), + audioDataSlider.Index(), asrScoreThreshold)); + +#if VERIFY_TEST_OUTPUT + arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]); +#endif /* VERIFY_TEST_OUTPUT */ + + /* Erase */ + str_inf = std::string(str_inf.size(), ' '); + platform.data_psn->present_data_text( + str_inf.c_str(), str_inf.size(), + dataPsnTxtInfStartX, dataPsnTxtInfStartY, false); + } + if (!_PresentInferenceResult(platform, asrResults)) { + return false; + } + + return true; + } + + /* Audio inference classification handler. */ + bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll) + { + auto& platform = ctx.Get("platform"); + platform.data_psn->clear(COLOR_BLACK); + + /* If the request has a valid size, set the audio index. */ + if (clipIndex < NUMBER_OF_FILES) { + if (!_SetAppCtxClipIdx(ctx, clipIndex)) { + return false; + } + } + + auto startClipIdx = ctx.Get("clipIndex"); + + do { + KWSOutput kwsOutput = doKws(ctx); + if (!kwsOutput.executionSuccess) { + return false; + } + + if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) { + info("Keyword spotted\n"); + if(!doAsr(ctx, kwsOutput)) { + printf_err("ASR failed"); + return false; + } + } + + _IncrementAppCtxClipIdx(ctx); + + } while (runAll && ctx.Get("clipIndex") != startClipIdx); + + return true; + } + + static void _IncrementAppCtxClipIdx(ApplicationContext& ctx) + { + auto curAudioIdx = ctx.Get("clipIndex"); + + if (curAudioIdx + 1 >= NUMBER_OF_FILES) { + ctx.Set("clipIndex", 0); + return; + } + ++curAudioIdx; + ctx.Set("clipIndex", curAudioIdx); + } + + static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx) + { + if (idx >= NUMBER_OF_FILES) { + printf_err("Invalid idx %u (expected less than %u)\n", + idx, NUMBER_OF_FILES); + return false; + } + ctx.Set("clipIndex", idx); + return true; + } + + static bool _PresentInferenceResult(hal_platform& platform, + std::vector& results) + { + constexpr uint32_t dataPsnTxtStartX1 = 20; + constexpr uint32_t dataPsnTxtStartY1 = 30; + constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */ + + platform.data_psn->set_text_color(COLOR_GREEN); + + /* Display each result. 
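+ * The string assembled below has the form "@<time>s: <keyword> (<score>%)",
+ * e.g. (hypothetical values) "@1.500000s: on (93%)".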
*/ + uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr; + + for (uint32_t i = 0; i < results.size(); ++i) { + + std::string topKeyword{""}; + float score = 0.f; + + if (results[i].m_resultVec.size()) { + topKeyword = results[i].m_resultVec[0].m_label; + score = results[i].m_resultVec[0].m_normalisedVal; + } + + std::string resultStr = + std::string{"@"} + std::to_string(results[i].m_timeStamp) + + std::string{"s: "} + topKeyword + std::string{" ("} + + std::to_string(static_cast(score * 100)) + std::string{"%)"}; + + platform.data_psn->present_data_text( + resultStr.c_str(), resultStr.size(), + dataPsnTxtStartX1, rowIdx1, 0); + rowIdx1 += dataPsnTxtYIncr; + + info("For timestamp: %f (inference #: %u); threshold: %f\n", + results[i].m_timeStamp, results[i].m_inferenceNumber, + results[i].m_threshold); + for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) { + info("\t\tlabel @ %u: %s, score: %f\n", j, + results[i].m_resultVec[j].m_label.c_str(), + results[i].m_resultVec[j].m_normalisedVal); + } + } + + return true; + } + + static bool _PresentInferenceResult(hal_platform& platform, std::vector& results) + { + constexpr uint32_t dataPsnTxtStartX1 = 20; + constexpr uint32_t dataPsnTxtStartY1 = 80; + constexpr bool allow_multiple_lines = true; + + platform.data_psn->set_text_color(COLOR_GREEN); + + /* Results from multiple inferences should be combined before processing. */ + std::vector combinedResults; + for (auto& result : results) { + combinedResults.insert(combinedResults.end(), + result.m_resultVec.begin(), + result.m_resultVec.end()); + } + + for (auto& result : results) { + /* Get the final result string using the decoder. */ + std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec); + + info("Result for inf %u: %s\n", result.m_inferenceNumber, + infResultStr.c_str()); + } + + std::string finalResultStr = audio::asr::DecodeOutput(combinedResults); + + platform.data_psn->present_data_text( + finalResultStr.c_str(), finalResultStr.size(), + dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines); + + info("Final result: %s\n", finalResultStr.c_str()); + return true; + } + + /** + * @brief Generic feature calculator factory. + * + * Returns lambda function to compute features using features cache. + * Real features math is done by a lambda function provided as a parameter. + * Features are written to input tensor memory. + * + * @tparam T feature vector type. + * @param inputTensor model input tensor pointer. + * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap. + * @param compute features calculator function. + * @return lambda function to compute features. + **/ + template + std::function&, size_t, bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize, + std::function (std::vector& )> compute) + { + /* Feature cache to be captured by lambda function. */ + static std::vector> featureCache = std::vector>(cacheSize); + + return [=](std::vector& audioDataWindow, + size_t index, + bool useCache, + size_t featuresOverlapIndex) + { + T* tensorData = tflite::GetTensorData(inputTensor); + std::vector features; + + /* Reuse features from cache if cache is ready and sliding windows overlap. + * Overlap is in the beginning of sliding window with a size of a feature cache. 
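+ * With the default KWS parameters in this patch (49 MFCC windows per inference,
+ * 25 windows per audio stride) this means cacheSize is 24: on the second and
+ * later inferences, windows 0..23 are taken from the cache, windows 24..48 are
+ * recomputed, and every window with index >= 25 is stored at slot index - 25
+ * ready for the next inference.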
+ */ + if (useCache && index < featureCache.size()) { + features = std::move(featureCache[index]); + } else { + features = std::move(compute(audioDataWindow)); + } + auto size = features.size(); + auto sizeBytes = sizeof(T) * size; + std::memcpy(tensorData + (index * size), features.data(), sizeBytes); + + /* Start renewing cache as soon iteration goes out of the windows overlap. */ + if (index >= featuresOverlapIndex) { + featureCache[index - featuresOverlapIndex] = std::move(features); + } + }; + } + + template std::function&, size_t , bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector& )> compute); + + template std::function&, size_t , bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector& )> compute); + + template std::function&, size_t , bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function (std::vector& )> compute); + + template std::function&, size_t, bool, size_t)> + _FeatureCalc(TfLiteTensor* inputTensor, + size_t cacheSize, + std::function(std::vector&)> compute); + + + static std::function&, int, bool, size_t)> + GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + { + std::function&, size_t, bool, size_t)> mfccFeatureCalc; + + TfLiteQuantization quant = inputTensor->quantization; + + if (kTfLiteAffineQuantization == quant.type) { + + auto* quantParams = (TfLiteAffineQuantization*) quant.params; + const float quantScale = quantParams->scale->data[0]; + const int quantOffset = quantParams->zero_point->data[0]; + + switch (inputTensor->type) { + case kTfLiteInt8: { + mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &mfcc](std::vector& audioDataWindow) { + return mfcc.MfccComputeQuant(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + case kTfLiteUInt8: { + mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &mfcc](std::vector& audioDataWindow) { + return mfcc.MfccComputeQuant(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + case kTfLiteInt16: { + mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [=, &mfcc](std::vector& audioDataWindow) { + return mfcc.MfccComputeQuant(audioDataWindow, + quantScale, + quantOffset); + } + ); + break; + } + default: + printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type)); + } + + + } else { + mfccFeatureCalc = mfccFeatureCalc = _FeatureCalc(inputTensor, + cacheSize, + [&mfcc](std::vector& audioDataWindow) { + return mfcc.MfccCompute(audioDataWindow); + }); + } + return mfccFeatureCalc; + } +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterMfcc.cc b/source/use_case/kws_asr/src/Wav2LetterMfcc.cc new file mode 100644 index 0000000..80e4a26 --- /dev/null +++ b/source/use_case/kws_asr/src/Wav2LetterMfcc.cc @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterMfcc.hpp" + +#include "PlatformMath.hpp" + +#include + +namespace arm { +namespace app { +namespace audio { + + bool Wav2LetterMFCC::ApplyMelFilterBank( + std::vector& fftVec, + std::vector>& melFilterBank, + std::vector& filterBankFilterFirst, + std::vector& filterBankFilterLast, + std::vector& melEnergies) + { + const size_t numBanks = melEnergies.size(); + + if (numBanks != filterBankFilterFirst.size() || + numBanks != filterBankFilterLast.size()) { + printf_err("unexpected filter bank lengths\n"); + return false; + } + + for (size_t bin = 0; bin < numBanks; ++bin) { + auto filterBankIter = melFilterBank[bin].begin(); + float melEnergy = 1e-10; /* Avoid log of zero at later stages, same value used in librosa. */ + const int32_t firstIndex = filterBankFilterFirst[bin]; + const int32_t lastIndex = filterBankFilterLast[bin]; + + for (int32_t i = firstIndex; i <= lastIndex; ++i) { + melEnergy += (*filterBankIter++ * fftVec[i]); + } + + melEnergies[bin] = melEnergy; + } + + return true; + } + + void Wav2LetterMFCC::ConvertToLogarithmicScale( + std::vector& melEnergies) + { + float maxMelEnergy = -FLT_MAX; + + /* Container for natural logarithms of mel energies. */ + std::vector vecLogEnergies(melEnergies.size(), 0.f); + + /* Because we are taking natural logs, we need to multiply by log10(e). + * Also, for wav2letter model, we scale our log10 values by 10. */ + constexpr float multiplier = 10.0 * /* Default scalar. */ + 0.4342944819032518; /* log10f(std::exp(1.0))*/ + + /* Take log of the whole vector. */ + math::MathUtils::VecLogarithmF32(melEnergies, vecLogEnergies); + + /* Scale the log values and get the max. */ + for (auto iterM = melEnergies.begin(), iterL = vecLogEnergies.begin(); + iterM != melEnergies.end(); ++iterM, ++iterL) { + + *iterM = *iterL * multiplier; + + /* Save the max mel energy. */ + if (*iterM > maxMelEnergy) { + maxMelEnergy = *iterM; + } + } + + /* Clamp the mel energies. */ + constexpr float maxDb = 80.0; + const float clampLevelLowdB = maxMelEnergy - maxDb; + for (auto iter = melEnergies.begin(); iter != melEnergies.end(); ++iter) { + *iter = std::max(*iter, clampLevelLowdB); + } + } + + std::vector Wav2LetterMFCC::CreateDCTMatrix( + const int32_t inputLength, + const int32_t coefficientCount) + { + std::vector dctMatix(inputLength * coefficientCount); + + /* Orthonormal normalization. */ + const float normalizerK0 = 2 * math::MathUtils::SqrtF32(1.0f / + static_cast(4*inputLength)); + const float normalizer = 2 * math::MathUtils::SqrtF32(1.0f / + static_cast(2*inputLength)); + + const float angleIncr = M_PI/inputLength; + float angle = angleIncr; /* We start using it at k = 1 loop. */ + + /* First row of DCT will use normalizer K0 */ + for (int32_t n = 0; n < inputLength; ++n) { + dctMatix[n] = normalizerK0 /* cos(0) = 1 */; + } + + /* Second row (index = 1) onwards, we use standard normalizer. */ + for (int32_t k = 1, m = inputLength; k < coefficientCount; ++k, m += inputLength) { + for (int32_t n = 0; n < inputLength; ++n) { + dctMatix[m+n] = normalizer * + math::MathUtils::CosineF32((n + 0.5f) * angle); + } + angle += angleIncr; + } + return dctMatix; + } + + float Wav2LetterMFCC::GetMelFilterBankNormaliser( + const float& leftMel, + const float& rightMel, + const bool useHTKMethod) + { + /* Slaney normalization for mel weights. 
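+ * i.e. weight = 2 / (hz(rightMel) - hz(leftMel)) with hz() = MFCC::InverseMelScale,
+ * so wider filters get proportionally smaller weights and every triangular
+ * filter keeps (approximately) unit area.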
*/ + return (2.0f / (MFCC::InverseMelScale(rightMel, useHTKMethod) - + MFCC::InverseMelScale(leftMel, useHTKMethod))); + } + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ diff --git a/source/use_case/kws_asr/src/Wav2LetterModel.cc b/source/use_case/kws_asr/src/Wav2LetterModel.cc new file mode 100644 index 0000000..2114a3f --- /dev/null +++ b/source/use_case/kws_asr/src/Wav2LetterModel.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterModel.hpp" + +#include "hal.h" + +namespace arm { +namespace app { +namespace asr { + extern uint8_t* GetModelPointer(); + extern size_t GetModelLen(); +} +} /* namespace app */ +} /* namespace arm */ + +const tflite::MicroOpResolver& arm::app::Wav2LetterModel::GetOpResolver() +{ + return this->_m_opResolver; +} + +bool arm::app::Wav2LetterModel::EnlistOperations() +{ + this->_m_opResolver.AddConv2D(); + this->_m_opResolver.AddMul(); + this->_m_opResolver.AddMaximum(); + this->_m_opResolver.AddReshape(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->_m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +const uint8_t* arm::app::Wav2LetterModel::ModelPointer() +{ + return arm::app::asr::GetModelPointer(); +} + +size_t arm::app::Wav2LetterModel::ModelSize() +{ + return arm::app::asr::GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc new file mode 100644 index 0000000..b173968 --- /dev/null +++ b/source/use_case/kws_asr/src/Wav2LetterPostprocess.cc @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "Wav2LetterPostprocess.hpp" + +#include "Wav2LetterModel.hpp" + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + Postprocess::Postprocess(const uint32_t contextLen, + const uint32_t innerLen, + const uint32_t blankTokenIdx) + : _m_contextLen(contextLen), + _m_innerLen(innerLen), + _m_totalLen(2 * this->_m_contextLen + this->_m_innerLen), + _m_countIterations(0), + _m_blankTokenIdx(blankTokenIdx) + {} + + bool Postprocess::Invoke(TfLiteTensor* tensor, + const uint32_t axisIdx, + const bool lastIteration) + { + /* Basic checks. */ + if (!this->_IsInputValid(tensor, axisIdx)) { + return false; + } + + /* Irrespective of tensor type, we use unsigned "byte" */ + uint8_t* ptrData = tflite::GetTensorData(tensor); + const uint32_t elemSz = this->_GetTensorElementSize(tensor); + + /* Other sanity checks. */ + if (0 == elemSz) { + printf_err("Tensor type not supported for post processing\n"); + return false; + } else if (elemSz * this->_m_totalLen > tensor->bytes) { + printf_err("Insufficient number of tensor bytes\n"); + return false; + } + + /* Which axis do we need to process? */ + switch (axisIdx) { + case arm::app::Wav2LetterModel::ms_outputRowsIdx: + return this->_EraseSectionsRowWise(ptrData, + elemSz * tensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx], + lastIteration); + default: + printf_err("Unsupported axis index: %u\n", axisIdx); + } + + return false; + } + + bool Postprocess::_IsInputValid(TfLiteTensor* tensor, + const uint32_t axisIdx) const + { + if (nullptr == tensor) { + return false; + } + + if (static_cast(axisIdx) >= tensor->dims->size) { + printf_err("Invalid axis index: %u; Max: %d\n", + axisIdx, tensor->dims->size); + return false; + } + + if (static_cast(this->_m_totalLen) != + tensor->dims->data[axisIdx]) { + printf_err("Unexpected tensor dimension for axis %d, \n", + tensor->dims->data[axisIdx]); + return false; + } + + return true; + } + + uint32_t Postprocess::_GetTensorElementSize(TfLiteTensor* tensor) + { + switch(tensor->type) { + case kTfLiteUInt8: + return 1; + case kTfLiteInt8: + return 1; + case kTfLiteInt16: + return 2; + case kTfLiteInt32: + return 4; + case kTfLiteFloat32: + return 4; + default: + printf_err("Unsupported tensor type %s\n", + TfLiteTypeGetName(tensor->type)); + } + + return 0; + } + + bool Postprocess::_EraseSectionsRowWise( + uint8_t* ptrData, + const uint32_t strideSzBytes, + const bool lastIteration) + { + /* In this case, the "zero-ing" is quite simple as the region + * to be zeroed sits in contiguous memory (row-major). */ + const uint32_t eraseLen = strideSzBytes * this->_m_contextLen; + + /* Erase left context? */ + if (this->_m_countIterations > 0) { + /* Set output of each classification window to the blank token. */ + std::memset(ptrData, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) { + ptrData[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1; + } + } + + /* Erase right context? */ + if (false == lastIteration) { + uint8_t * rightCtxPtr = ptrData + (strideSzBytes * (this->_m_contextLen + this->_m_innerLen)); + /* Set output of each classification window to the blank token. 
*/ + std::memset(rightCtxPtr, 0, eraseLen); + for (size_t windowIdx = 0; windowIdx < this->_m_contextLen; windowIdx++) { + rightCtxPtr[windowIdx*strideSzBytes + this->_m_blankTokenIdx] = 1; + } + } + + if (lastIteration) { + this->_m_countIterations = 0; + } else { + ++this->_m_countIterations; + } + + return true; + } + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc new file mode 100644 index 0000000..613ddb0 --- /dev/null +++ b/source/use_case/kws_asr/src/Wav2LetterPreprocess.cc @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Wav2LetterPreprocess.hpp" + +#include "PlatformMath.hpp" +#include "TensorFlowLiteMicro.hpp" + +#include +#include + +namespace arm { +namespace app { +namespace audio { +namespace asr { + + Preprocess::Preprocess( + const uint32_t numMfccFeatures, + const uint32_t windowLen, + const uint32_t windowStride, + const uint32_t numMfccVectors): + _m_mfcc(numMfccFeatures, windowLen), + _m_mfccBuf(numMfccFeatures, numMfccVectors), + _m_delta1Buf(numMfccFeatures, numMfccVectors), + _m_delta2Buf(numMfccFeatures, numMfccVectors), + _m_windowLen(windowLen), + _m_windowStride(windowStride), + _m_numMfccFeats(numMfccFeatures), + _m_numFeatVectors(numMfccVectors), + _m_window() + { + if (numMfccFeatures > 0 && windowLen > 0) { + this->_m_mfcc.Init(); + } + } + + bool Preprocess::Invoke( + const int16_t* audioData, + const uint32_t audioDataLen, + TfLiteTensor* tensor) + { + this->_m_window = SlidingWindow( + audioData, audioDataLen, + this->_m_windowLen, this->_m_windowStride); + + uint32_t mfccBufIdx = 0; + + std::fill(_m_mfccBuf.begin(), _m_mfccBuf.end(), 0.f); + std::fill(_m_delta1Buf.begin(), _m_delta1Buf.end(), 0.f); + std::fill(_m_delta2Buf.begin(), _m_delta2Buf.end(), 0.f); + + /* While we can slide over the window. */ + while (this->_m_window.HasNext()) { + const int16_t* mfccWindow = this->_m_window.Next(); + auto mfccAudioData = std::vector( + mfccWindow, + mfccWindow + this->_m_windowLen); + auto mfcc = this->_m_mfcc.MfccCompute(mfccAudioData); + for (size_t i = 0; i < this->_m_mfccBuf.size(0); ++i) { + this->_m_mfccBuf(i, mfccBufIdx) = mfcc[i]; + } + ++mfccBufIdx; + } + + /* Pad MFCC if needed by adding MFCC for zeros. */ + if (mfccBufIdx != this->_m_numFeatVectors) { + std::vector zerosWindow = std::vector(this->_m_windowLen, 0); + std::vector mfccZeros = this->_m_mfcc.MfccCompute(zerosWindow); + + while (mfccBufIdx != this->_m_numFeatVectors) { + memcpy(&this->_m_mfccBuf(0, mfccBufIdx), + mfccZeros.data(), sizeof(float) * _m_numMfccFeats); + ++mfccBufIdx; + } + } + + /* Compute first and second order deltas from MFCCs. */ + this->_ComputeDeltas(this->_m_mfccBuf, + this->_m_delta1Buf, + this->_m_delta2Buf); + + /* Normalise. 
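+ * Standard-score normalisation: each buffer is mapped to (x - mean) / stddev,
+ * implemented in _NormaliseVec() as x * (1/stddev) - mean/stddev, with an
+ * all-zero fallback when stddev is 0.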
*/ + this->_Normalise(); + + /* Quantise. */ + QuantParams quantParams = GetTensorQuantParams(tensor); + + if (0 == quantParams.scale) { + printf_err("Quantisation scale can't be 0\n"); + return false; + } + + switch(tensor->type) { + case kTfLiteUInt8: + return this->_Quantise( + tflite::GetTensorData(tensor), tensor->bytes, + quantParams.scale, quantParams.offset); + case kTfLiteInt8: + return this->_Quantise( + tflite::GetTensorData(tensor), tensor->bytes, + quantParams.scale, quantParams.offset); + default: + printf_err("Unsupported tensor type %s\n", + TfLiteTypeGetName(tensor->type)); + } + + return false; + } + + bool Preprocess::_ComputeDeltas(Array2d& mfcc, + Array2d& delta1, + Array2d& delta2) + { + const std::vector delta1Coeffs = + {6.66666667e-02, 5.00000000e-02, 3.33333333e-02, + 1.66666667e-02, -3.46944695e-18, -1.66666667e-02, + -3.33333333e-02, -5.00000000e-02, -6.66666667e-02}; + + const std::vector delta2Coeffs = + {0.06060606, 0.01515152, -0.01731602, + -0.03679654, -0.04329004, -0.03679654, + -0.01731602, 0.01515152, 0.06060606}; + + if (delta1.size(0) == 0 || delta2.size(0) != delta1.size(0) || + mfcc.size(0) == 0 || mfcc.size(1) == 0) { + return false; + } + + /* Get the middle index; coeff vec len should always be odd. */ + const size_t coeffLen = delta1Coeffs.size(); + const size_t fMidIdx = (coeffLen - 1)/2; + const size_t numFeatures = mfcc.size(0); + const size_t numFeatVectors = mfcc.size(1); + + /* Iterate through features in MFCC vector. */ + for (size_t i = 0; i < numFeatures; ++i) { + /* For each feature, iterate through time (t) samples representing feature evolution and + * calculate d/dt and d^2/dt^2, using 1d convolution with differential kernels. + * Convolution padding = valid, result size is `time length - kernel length + 1`. + * The result is padded with 0 from both sides to match the size of initial time samples data. + * + * For the small filter, conv1d implementation as a simple loop is efficient enough. + * Filters of a greater size would need CMSIS-DSP functions to be used, like arm_fir_f32. 
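+ * Concretely, with the 9-tap kernels above (fMidIdx = 4) only j in
+ * [4, numFeatVectors - 4) is computed, and each output is
+ * d1(i, j) = sum over k = 0..8 of mfcc(i, j - 4 + k) * delta1Coeffs[8 - k],
+ * which is what the loop below evaluates.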
+ */ + + for (size_t j = fMidIdx; j < numFeatVectors - fMidIdx; ++j) { + float d1 = 0; + float d2 = 0; + const size_t mfccStIdx = j - fMidIdx; + + for (size_t k = 0, m = coeffLen - 1; k < coeffLen; ++k, --m) { + + d1 += mfcc(i,mfccStIdx + k) * delta1Coeffs[m]; + d2 += mfcc(i,mfccStIdx + k) * delta2Coeffs[m]; + } + + delta1(i,j) = d1; + delta2(i,j) = d2; + } + } + + return true; + } + + float Preprocess::_GetMean(Array2d& vec) + { + return math::MathUtils::MeanF32(vec.begin(), vec.totalSize()); + } + + float Preprocess::_GetStdDev(Array2d& vec, const float mean) + { + return math::MathUtils::StdDevF32(vec.begin(), vec.totalSize(), mean); + } + + void Preprocess::_NormaliseVec(Array2d& vec) + { + auto mean = Preprocess::_GetMean(vec); + auto stddev = Preprocess::_GetStdDev(vec, mean); + + debug("Mean: %f, Stddev: %f\n", mean, stddev); + if (stddev == 0) { + std::fill(vec.begin(), vec.end(), 0); + } else { + const float stddevInv = 1.f/stddev; + const float normalisedMean = mean/stddev; + + auto NormalisingFunction = [=](float& value) { + value = value * stddevInv - normalisedMean; + }; + std::for_each(vec.begin(), vec.end(), NormalisingFunction); + } + } + + void Preprocess::_Normalise() + { + Preprocess::_NormaliseVec(this->_m_mfccBuf); + Preprocess::_NormaliseVec(this->_m_delta1Buf); + Preprocess::_NormaliseVec(this->_m_delta2Buf); + } + + float Preprocess::_GetQuantElem( + const float elem, + const float quantScale, + const int quantOffset, + const float minVal, + const float maxVal) + { + float val = std::round((elem/quantScale) + quantOffset); + return std::min(std::max(val, minVal), maxVal); + } + +} /* namespace asr */ +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ \ No newline at end of file diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake new file mode 100644 index 0000000..f15bc73 --- /dev/null +++ b/source/use_case/kws_asr/usecase.cmake @@ -0,0 +1,259 @@ +#---------------------------------------------------------------------------- +# Copyright (c) 2021 Arm Limited. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#---------------------------------------------------------------------------- + +# If the path to a directory or source file has been defined, +# get the type here (FILEPATH or PATH): +if (DEFINED ${use_case}_FILE_PATH) + get_path_type(${${use_case}_FILE_PATH} PATH_TYPE) + + # Set the default type if path is not a dir or file path (or undefined) + if (NOT ${PATH_TYPE} STREQUAL PATH AND NOT ${PATH_TYPE} STREQUAL FILEPATH) + message(FATAL_ERROR "Invalid ${use_case}_FILE_PATH. It should be a dir or file path.") + endif() +else() + # Default is a directory path + set(PATH_TYPE PATH) +endif() + +message(STATUS "${use_case}_FILE_PATH is of type: ${PATH_TYPE}") + +USER_OPTION(${use_case}_FILE_PATH "Directory with WAV files, or path to a single WAV file, to use in the evaluation application." 
+ ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/samples/ + ${PATH_TYPE}) + +USER_OPTION(${use_case}_AUDIO_RATE "Specify the target sampling rate. Default is 16000." + 16000 + STRING) + +USER_OPTION(${use_case}_AUDIO_MONO "Specify if the audio needs to be converted to mono. Default is ON." + ON + BOOL) + +USER_OPTION(${use_case}_AUDIO_OFFSET "Specify the offset to start reading after this time (in seconds). Default is 0." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_DURATION "Specify the audio duration to load (in seconds). If set to 0 the entire audio will be processed." + 0 + STRING) + +USER_OPTION(${use_case}_AUDIO_RES_TYPE "Specify re-sampling algorithm to use. By default is 'kaiser_best'." + kaiser_best + STRING) + +USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples to use. By default is 16000, if the audio is shorter will be automatically padded." + 16000 + STRING) + +# Generate audio .cc files: +generate_audio_code(${${use_case}_FILE_PATH} ${SRC_GEN_DIR} ${INC_GEN_DIR} + ${${use_case}_AUDIO_RATE} + ${${use_case}_AUDIO_MONO} + ${${use_case}_AUDIO_OFFSET} + ${${use_case}_AUDIO_DURATION} + ${${use_case}_AUDIO_RES_TYPE} + ${${use_case}_AUDIO_MIN_SAMPLES}) + +# Generate kws labels file: +USER_OPTION(${use_case}_LABELS_TXT_FILE_KWS "Labels' txt file for the chosen model." + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/ds_cnn_labels.txt + FILEPATH) + +set(${use_case}_LABELS_CPP_FILE_KWS Labels_dscnn) +generate_labels_code( + INPUT "${${use_case}_LABELS_TXT_FILE_KWS}" + DESTINATION_SRC ${SRC_GEN_DIR} + DESTINATION_HDR ${INC_GEN_DIR} + OUTPUT_FILENAME "${${use_case}_LABELS_CPP_FILE_KWS}" + NAMESPACE "arm" "app" "kws" +) + +# Generate asr labels file: +USER_OPTION(${use_case}_LABELS_TXT_FILE_ASR "Labels' txt file for the chosen model." + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/labels_wav2letter.txt + FILEPATH) + +set(${use_case}_LABELS_CPP_FILE_ASR Labels_wav2letter) +generate_labels_code( + INPUT "${${use_case}_LABELS_TXT_FILE_ASR}" + DESTINATION_SRC ${SRC_GEN_DIR} + DESTINATION_HDR ${INC_GEN_DIR} + OUTPUT_FILENAME "${${use_case}_LABELS_CPP_FILE_ASR}" + NAMESPACE "arm" "app" "asr" +) + +USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen model" + 0x00200000 + STRING) + +USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_KWS "Specify the score threshold [0.0, 1.0) that must be applied to the KWS results for a label to be deemed valid." + 0.9 + STRING) + +USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [0.0, 1.0) that must be applied to the ASR results for a label to be deemed valid." + 0.5 + STRING) + +# If there is no tflite file pointed to +if (NOT DEFINED ${use_case}_MODEL_TFLITE_PATH_KWS) + + set(SUB_USECASE_KWS "kws") + set(MODEL_FILENAME_KWS ds_cnn_clustered_int8.tflite) + set(MODEL_RESOURCES_DIR_KWS ${DOWNLOAD_DEP_DIR}/${use_case}) + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR_KWS}) + set(DEFAULT_MODEL_PATH_KWS ${MODEL_RESOURCES_DIR_KWS}/${MODEL_FILENAME_KWS}) + + # Download the default model + set(ZOO_COMMON_SUBPATH_KWS "models/keyword_spotting/ds_cnn_large/tflite_clustered_int8") + set(ZOO_MODEL_SUBPATH_KWS "${ZOO_COMMON_SUBPATH_KWS}/${MODEL_FILENAME_KWS}") + + download_file_from_modelzoo(${ZOO_MODEL_SUBPATH_KWS} ${DEFAULT_MODEL_PATH_KWS}) + + if (ETHOS_U55_ENABLED) + message(STATUS + "Ethos-U55 is enabled, but the model downloaded is not optimized by vela. 
" + "To use Ethos-U55 acceleration, optimise the downloaded model and pass it " + "as ${use_case}_MODEL_TFLITE_PATH_KWS to the CMake configuration.") + endif() + + if (${TARGET_PLATFORM} STREQUAL native) + + # Download test vectors + set(ZOO_TEST_IFM_SUBPATH_KWS "${ZOO_COMMON_SUBPATH_KWS}/testing_input/input_2/0.npy") + set(ZOO_TEST_OFM_SUBPATH_KWS "${ZOO_COMMON_SUBPATH_KWS}/testing_output/Identity/0.npy") + + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR_KWS}/${SUB_USECASE_KWS}) + set(${use_case}_TEST_IFM ${MODEL_RESOURCES_DIR_KWS}/${SUB_USECASE_KWS}/ifm0.npy CACHE FILEPATH + "Input test vector for ${use_case}-${SUB_USECASE_KWS}") + set(${use_case}_TEST_OFM ${MODEL_RESOURCES_DIR_KWS}/${SUB_USECASE_KWS}/ofm0.npy CACHE FILEPATH + "Input test vector for ${use_case}-${SUB_USECASE_KWS}.") + + download_file_from_modelzoo(${ZOO_TEST_IFM_SUBPATH_KWS} ${${use_case}_TEST_IFM}) + download_file_from_modelzoo(${ZOO_TEST_OFM_SUBPATH_KWS} ${${use_case}_TEST_OFM}) + set(TEST_SRC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/src) + set(TEST_INC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/include) + file(MAKE_DIRECTORY ${TEST_SRC_GEN_DIR} ${TEST_INC_GEN_DIR}) + + generate_test_data_code( + INPUT_DIR "${DOWNLOAD_DEP_DIR}/${use_case}/${SUB_USECASE_KWS}" + DESTINATION_SRC ${TEST_SRC_GEN_DIR} + DESTINATION_HDR ${TEST_INC_GEN_DIR} + USECASE ${SUB_USECASE_KWS} + NAMESPACE "arm" "app" ${SUB_USECASE_KWS}) + endif() + +else() + set(DEFAULT_MODEL_PATH_KWS "N/A") +endif() + +set(EXTRA_MODEL_CODE_KWS + "/* Model parameters for ${use_case} */" + "extern const uint32_t g_NumMfcc = 10" + "extern const uint32_t g_NumAudioWins = 49" + "extern const int g_FrameLength = 640" + "extern const int g_FrameStride = 320" + "extern const float g_ScoreThreshold = ${${use_case}_MODEL_SCORE_THRESHOLD_KWS}" + ) + +# If there is no tflite file pointed to +if (NOT DEFINED ${use_case}_MODEL_TFLITE_PATH_ASR) + + set(SUB_USECASE_ASR "asr") + set(MODEL_FILENAME_ASR wav2letter_int8.tflite) + set(MODEL_RESOURCES_DIR_ASR ${DOWNLOAD_DEP_DIR}/${use_case}) + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR_ASR}) + set(DEFAULT_MODEL_PATH_ASR ${MODEL_RESOURCES_DIR_ASR}/${MODEL_FILENAME_ASR}) + + # Download the default model + set(ZOO_COMMON_SUBPATH_ASR "models/speech_recognition/wav2letter/tflite_int8") + set(ZOO_MODEL_SUBPATH_ASR "${ZOO_COMMON_SUBPATH_ASR}/${MODEL_FILENAME_ASR}") + + download_file_from_modelzoo(${ZOO_MODEL_SUBPATH_ASR} ${DEFAULT_MODEL_PATH_ASR}) + + if (ETHOS_U55_ENABLED) + message(STATUS + "Ethos-U55 is enabled, but the model downloaded is not optimized by vela. 
" + "To use Ethos-U55 acceleration, optimise the downloaded model and pass it " + "as ${use_case}_MODEL_TFLITE_PATH to the CMake configuration.") + endif() + + # If the target platform is native + if (${TARGET_PLATFORM} STREQUAL native) + + # Download test vectors + set(ZOO_TEST_IFM_SUBPATH_ASR "${ZOO_COMMON_SUBPATH_ASR}/testing_input/input_2_int8/0.npy") + set(ZOO_TEST_OFM_SUBPATH_ASR "${ZOO_COMMON_SUBPATH_ASR}/testing_output/Identity_int8/0.npy") + + file(MAKE_DIRECTORY ${MODEL_RESOURCES_DIR_ASR}/${SUB_USECASE_ASR}) + set(${use_case}_TEST_IFM_ASR ${MODEL_RESOURCES_DIR_ASR}/${SUB_USECASE_ASR}/ifm0.npy CACHE FILEPATH + "Input test vector for ${use_case}-${SUB_USECASE_ASR}") + set(${use_case}_TEST_OFM_ASR ${MODEL_RESOURCES_DIR_ASR}/${SUB_USECASE_ASR}/ofm0.npy CACHE FILEPATH + "Input test vector for ${use_case}-${SUB_USECASE_ASR}") + + download_file_from_modelzoo(${ZOO_TEST_IFM_SUBPATH_KWS} ${${use_case}_TEST_IFM_ASR}) + download_file_from_modelzoo(${ZOO_TEST_OFM_SUBPATH_KWS} ${${use_case}_TEST_OFM_ASR}) + + set(TEST_SRC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/src) + set(TEST_INC_GEN_DIR ${CMAKE_BINARY_DIR}/generated/${use_case}/tests/include) + file(MAKE_DIRECTORY ${TEST_SRC_GEN_DIR} ${TEST_INC_GEN_DIR}) + + # Generate test data files to be included in x86 tests + generate_test_data_code( + INPUT_DIR "${DOWNLOAD_DEP_DIR}/${use_case}/${SUB_USECASE_ASR}" + DESTINATION_SRC ${TEST_SRC_GEN_DIR} + DESTINATION_HDR ${TEST_INC_GEN_DIR} + USECASE ${SUB_USECASE_ASR} + NAMESPACE "arm" "app" ${SUB_USECASE_ASR}) + endif() + +else() + set(DEFAULT_MODEL_PATH_ASR "N/A") +endif() + +set(EXTRA_MODEL_CODE_ASR + "/* Model parameters for ${use_case} */" + "extern const int g_FrameLength = 512" + "extern const int g_FrameStride = 160" + "extern const int g_ctxLen = 98" + "extern const float g_ScoreThreshold = ${${use_case}_MODEL_SCORE_THRESHOLD_ASR}" + ) + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH_KWS "NN models file to be used for KWS in the evaluation application. Model files must be in tflite format." + ${DEFAULT_MODEL_PATH_KWS} + FILEPATH + ) + +USER_OPTION(${use_case}_MODEL_TFLITE_PATH_ASR "NN models file to be used for ASR in the evaluation application. Model files must be in tflite format." + ${DEFAULT_MODEL_PATH_ASR} + FILEPATH + ) + +# Generate model file for KWS +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH_KWS} + DESTINATION ${SRC_GEN_DIR} + EXPRESSIONS ${EXTRA_MODEL_CODE_KWS} + NAMESPACE "arm" "app" "kws" +) + +# and for ASR +generate_tflite_code( + MODEL_PATH ${${use_case}_MODEL_TFLITE_PATH_ASR} + DESTINATION ${SRC_GEN_DIR} + EXPRESSIONS ${EXTRA_MODEL_CODE_ASR} + NAMESPACE "arm" "app" "asr" +) -- cgit v1.2.1