From 76a1580861210e0310db23acbc29e1064ae30ead Mon Sep 17 00:00:00 2001 From: Kshitij Sisodia Date: Fri, 24 Dec 2021 11:05:11 +0000 Subject: MLECO-2599: Replace DSCNN with MicroNet for KWS Added SoftMax function to Mathutils to allow MicroNet to output probability as it does not nativelu have this layer. Minor refactoring to accommodate Softmax Calculations Extensive renaming and updating of documentation and resource download script. Added SoftMax function to Mathutils to allow MicroNet to output probability. Change-Id: I7cbbda1024d14b85c9ac1beea7ca8fbffd0b6eb5 Signed-off-by: Liam Barry --- Readme.md | 4 +- docs/documentation.md | 2 +- docs/quick_start.md | 32 +++--- docs/use_cases/kws.md | 71 +++++++------- docs/use_cases/kws_asr.md | 106 +++++++++----------- resources/kws/labels/ds_cnn_labels.txt | 12 --- resources/kws/labels/micronet_kws_labels.txt | 12 +++ resources/kws_asr/labels/ds_cnn_labels.txt | 12 --- resources/kws_asr/labels/micronet_kws_labels.txt | 12 +++ set_up_default_resources.py | 22 ++--- source/application/main/Classifier.cc | 105 ++++++++++---------- source/application/main/PlatformMath.cc | 20 ++++ source/application/main/include/Classifier.hpp | 23 ++++- source/application/main/include/PlatformMath.hpp | 7 ++ source/use_case/asr/include/AsrClassifier.hpp | 3 +- source/use_case/asr/src/AsrClassifier.cc | 3 +- source/use_case/kws/include/DsCnnMfcc.hpp | 50 ---------- source/use_case/kws/include/DsCnnModel.hpp | 59 ------------ source/use_case/kws/include/MicroNetKwsMfcc.hpp | 50 ++++++++++ source/use_case/kws/include/MicroNetKwsModel.hpp | 59 ++++++++++++ source/use_case/kws/src/DsCnnModel.cc | 58 ----------- source/use_case/kws/src/MainLoop.cc | 4 +- source/use_case/kws/src/MicroNetKwsModel.cc | 57 +++++++++++ source/use_case/kws/src/UseCaseHandler.cc | 24 +++-- source/use_case/kws/usecase.cmake | 9 +- source/use_case/kws_asr/include/AsrClassifier.hpp | 4 +- source/use_case/kws_asr/include/DsCnnMfcc.hpp | 51 ---------- 
source/use_case/kws_asr/include/DsCnnModel.hpp | 67 ------------- .../use_case/kws_asr/include/MicroNetKwsMfcc.hpp | 51 ++++++++++ .../use_case/kws_asr/include/MicroNetKwsModel.hpp | 66 +++++++++++++ source/use_case/kws_asr/src/AsrClassifier.cc | 3 +- source/use_case/kws_asr/src/DsCnnModel.cc | 67 ------------- source/use_case/kws_asr/src/MainLoop.cc | 10 +- source/use_case/kws_asr/src/MicroNetKwsModel.cc | 64 ++++++++++++ source/use_case/kws_asr/src/UseCaseHandler.cc | 20 ++-- source/use_case/kws_asr/usecase.cmake | 8 +- tests/common/ClassifierTests.cc | 4 +- tests/use_case/kws/InferenceTestDSCNN.cc | 107 --------------------- tests/use_case/kws/InferenceTestMicroNetKws.cc | 107 +++++++++++++++++++++ tests/use_case/kws/KWSHandlerTest.cc | 30 +++--- tests/use_case/kws/MfccTests.cc | 8 +- tests/use_case/kws_asr/InferenceTestDSCNN.cc | 105 -------------------- tests/use_case/kws_asr/InferenceTestMicroNetKws.cc | 105 ++++++++++++++++++++ tests/use_case/kws_asr/InitModels.cc | 6 +- tests/use_case/kws_asr/MfccTests.cc | 22 ++--- 45 files changed, 878 insertions(+), 843 deletions(-) delete mode 100644 resources/kws/labels/ds_cnn_labels.txt create mode 100644 resources/kws/labels/micronet_kws_labels.txt delete mode 100644 resources/kws_asr/labels/ds_cnn_labels.txt create mode 100644 resources/kws_asr/labels/micronet_kws_labels.txt delete mode 100644 source/use_case/kws/include/DsCnnMfcc.hpp delete mode 100644 source/use_case/kws/include/DsCnnModel.hpp create mode 100644 source/use_case/kws/include/MicroNetKwsMfcc.hpp create mode 100644 source/use_case/kws/include/MicroNetKwsModel.hpp delete mode 100644 source/use_case/kws/src/DsCnnModel.cc create mode 100644 source/use_case/kws/src/MicroNetKwsModel.cc delete mode 100644 source/use_case/kws_asr/include/DsCnnMfcc.hpp delete mode 100644 source/use_case/kws_asr/include/DsCnnModel.hpp create mode 100644 source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp create mode 100644 source/use_case/kws_asr/include/MicroNetKwsModel.hpp 
delete mode 100644 source/use_case/kws_asr/src/DsCnnModel.cc create mode 100644 source/use_case/kws_asr/src/MicroNetKwsModel.cc delete mode 100644 tests/use_case/kws/InferenceTestDSCNN.cc create mode 100644 tests/use_case/kws/InferenceTestMicroNetKws.cc delete mode 100644 tests/use_case/kws_asr/InferenceTestDSCNN.cc create mode 100644 tests/use_case/kws_asr/InferenceTestMicroNetKws.cc diff --git a/Readme.md b/Readme.md index 48a1773..6f95808 100644 --- a/Readme.md +++ b/Readme.md @@ -30,9 +30,9 @@ The example application at your disposal and the utilized models are listed in t | ML application | Description | Neural Network Model | | :----------------------------------: | :-----------------------------------------------------: | :----: | | [Image classification](./docs/use_cases/img_class.md) | Recognize the presence of objects in a given image | [Mobilenet V2](https://github.com/ARM-software/ML-zoo/tree/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8) | -| [Keyword spotting(KWS)](./docs/use_cases/kws.md) | Recognize the presence of a key word in a recording | [DS-CNN-L](https://github.com/ARM-software/ML-zoo/tree/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8) | +| [Keyword spotting(KWS)](./docs/use_cases/kws.md) | Recognize the presence of a key word in a recording | [MicroNet](https://github.com/ARM-software/ML-zoo/tree/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8) | | [Automated Speech Recognition(ASR)](./docs/use_cases/asr.md) | Transcribe words in a recording | [Wav2Letter](https://github.com/ARM-software/ML-zoo/tree/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_int8) | -| [KWS and ASR](./docs/use_cases/kws_asr.md) | Utilise Cortex-M and Ethos-U to transcribe words in a recording after a keyword was spotted | 
[DS-CNN-L](https://github.com/ARM-software/ML-zoo/tree/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8) [Wav2Letter](https://github.com/ARM-software/ML-zoo/tree/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_int8) | +| [KWS and ASR](./docs/use_cases/kws_asr.md) | Utilise Cortex-M and Ethos-U to transcribe words in a recording after a keyword was spotted | [MicroNet](https://github.com/ARM-software/ML-zoo/tree/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8) [Wav2Letter](https://github.com/ARM-software/ML-zoo/tree/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_int8) | | [Anomaly Detection](./docs/use_cases/ad.md) | Detecting abnormal behavior based on a sound recording of a machine | [MicroNet](https://github.com/ARM-software/ML-zoo/tree/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8/)| | [Visual Wake Word](./docs/use_cases/visual_wake_word.md) | Recognize if person is present in a given image | [MicroNet](https://github.com/ARM-software/ML-zoo/tree/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/vww4_128_128_INT8.tflite)| | [Noise Reduction](./docs/use_cases/noise_reduction.md) | Remove noise from audio while keeping speech intact | [RNNoise](https://github.com/ARM-software/ML-zoo/raw/a061600058097a2785d6f1f7785e5a2d2a142955/models/noise_suppression/RNNoise/tflite_int8) | diff --git a/docs/documentation.md b/docs/documentation.md index f1fab8c..2ec2028 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -211,7 +211,7 @@ What these folders contain: The models used in the use-cases implemented in this project can be downloaded from: [Arm ML-Zoo](https://github.com/ARM-software/ML-zoo). 
- [Mobilenet V2](https://github.com/ARM-software/ML-zoo/tree/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8) -- [DS-CNN](https://github.com/ARM-software/ML-zoo/tree/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b//models/keyword_spotting/ds_cnn_large/tflite_clustered_int8) +- [MicroNet for Keyword Spotting](https://github.com/ARM-software/ML-zoo/tree/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8) - [Wav2Letter](https://github.com/ARM-software/ML-zoo/tree/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8) - [MicroNet for Anomaly Detection](https://github.com/ARM-software/ML-zoo/tree/7c32b097f7d94aae2cd0b98a8ed5a3ba81e66b18/models/anomaly_detection/micronet_medium/tflite_int8) - [MicroNet for Visual Wake Word](https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/vww4_128_128_INT8.tflite) diff --git a/docs/quick_start.md b/docs/quick_start.md index 6522bdd..252b084 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -84,11 +84,11 @@ curl -L https://github.com/ARM-software/ML-zoo/raw/e0aa361b03c738047b9147d1a50e3 --output ./resources_downloaded/img_class/ifm0.npy curl -L https://github.com/ARM-software/ML-zoo/raw/e0aa361b03c738047b9147d1a50e3f2dcb13dbcb/models/image_classification/mobilenet_v2_1.0_224/tflite_int8/testing_output/MobilenetV2/Predictions/Reshape_11/0.npy \ --output ./resources_downloaded/img_class/ofm0.npy -curl -L https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/ds_cnn_clustered_int8.tflite \ - --output ./resources_downloaded/kws/ds_cnn_clustered_int8.tflite -curl -L 
https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_input/input_2/0.npy \ +curl -L https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/kws_micronet_m.tflite \ + --output ./resources_downloaded/kws/kws_micronet_m.tflite +curl -L https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_input/input/0.npy \ --output ./resources_downloaded/kws/ifm0.npy -curl -L https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_output/Identity/0.npy \ +curl -L https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_output/Identity/0.npy \ --output ./resources_downloaded/kws/ofm0.npy curl -L https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/wav2letter_pruned_int8.tflite \ --output ./resources_downloaded/kws_asr/wav2letter_pruned_int8.tflite @@ -96,11 +96,11 @@ curl -L https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59d --output ./resources_downloaded/kws_asr/asr/ifm0.npy curl -L https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity_int8/0.npy \ --output ./resources_downloaded/kws_asr/asr/ofm0.npy -curl -L https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/ds_cnn_clustered_int8.tflite \ - --output ./resources_downloaded/kws_asr/ds_cnn_clustered_int8.tflite -curl -L 
https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_input/input_2/0.npy \ +curl -L https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/kws_micronet_m.tflite \ + --output ./resources_downloaded/kws_asr/kws_micronet_m.tflite +curl -L https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_input/input/0.npy \ --output ./resources_downloaded/kws_asr/kws/ifm0.npy -curl -L https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_output/Identity/0.npy \ +curl -L https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_output/Identity/0.npy \ --output ./resources_downloaded/kws_asr/kws/ofm0.npy curl -L https://github.com/ARM-software/ML-zoo/raw/a061600058097a2785d6f1f7785e5a2d2a142955/models/noise_suppression/RNNoise/tflite_int8/rnnoise_INT8.tflite \ --output ./resources_downloaded/noise_reduction/rnnoise_INT8.tflite @@ -131,22 +131,22 @@ curl -L https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c curl -L https://github.com/ARM-software/ML-zoo/raw/7dd3b16bb84007daf88be8648983c07f3eb21140/models/visual_wake_words/micronet_vww4/tflite_int8/testing_output/Identity/0.npy \ --output ./resources_downloaded/vww/ofm0.npy -. resources_downloaded/env/bin/activate && vela resources_downloaded/kws/ds_cnn_clustered_int8.tflite \ +. 
resources_downloaded/env/bin/activate && vela resources_downloaded/kws/kws_micronet_m.tflite \ --accelerator-config=ethos-u55-128 \ --optimise Performance --config scripts/vela/default_vela.ini \ --memory-mode=Shared_Sram \ --system-config=Ethos_U55_High_End_Embedded \ --output-dir=resources_downloaded/kws \ --arena-cache-size=2097152 -mv resources_downloaded/kws/ds_cnn_clustered_int8_vela.tflite resources_downloaded/kws/ds_cnn_clustered_int8_vela_H128.tflite +mv resources_downloaded/kws/kws_micronet_m_vela.tflite resources_downloaded/kws/kws_micronet_m_vela_H128.tflite -. resources_downloaded/env/bin/activate && vela resources_downloaded/kws/kws_micronet_m.tflite \ --accelerator-config=ethos-u65-256 \ --optimise Performance --config scripts/vela/default_vela.ini \ --memory-mode=Dedicated_Sram \ --system-config=Ethos_U65_High_End \ --output-dir=resources_downloaded/kws -mv resources_downloaded/kws/ds_cnn_clustered_int8_vela.tflite resources_downloaded/kws/ds_cnn_clustered_int8_vela_Y256.tflite +mv resources_downloaded/kws/kws_micronet_m_vela.tflite resources_downloaded/kws/kws_micronet_m_vela_Y256.tflite . resources_downloaded/env/bin/activate && vela resources_downloaded/kws_asr/wav2letter_int8.tflite \ --accelerator-config=ethos-u55-128 \ --optimise Performance --config scripts/vela/default_vela.ini \ --memory-mode=Shared_Sram \ --system-config=Ethos_U55_High_End_Embedded \ --output-dir=resources_downloaded/kws_asr \ --arena-cache-size=2097152 -mv resources_downloaded/kws_asr/wav2letter_int8_vela.tflite resources_downloaded @@ -165,22 +165,22 @@ mv resources_downloaded/kws_asr/wav2letter_int8_vela.tflite resources_downloaded --output-dir=resources_downloaded/kws_asr mv resources_downloaded/kws_asr/wav2letter_int8_vela.tflite resources_downloaded/kws_asr/wav2letter_int8_vela_Y256.tflite -. 
resources_downloaded/env/bin/activate && vela resources_downloaded/kws_asr/kws_micronet_m.tflite \ --accelerator-config=ethos-u55-128 \ --optimise Performance --config scripts/vela/default_vela.ini \ --memory-mode=Shared_Sram \ --system-config=Ethos_U55_High_End_Embedded \ --output-dir=resources_downloaded/kws_asr \ --arena-cache-size=2097152 -mv resources_downloaded/kws_asr/ds_cnn_clustered_int8_vela.tflite resources_downloaded/kws_asr/ds_cnn_clustered_int8_vela_H128.tflite +mv resources_downloaded/kws_asr/kws_micronet_m_vela.tflite resources_downloaded/kws_asr/kws_micronet_m_vela_H128.tflite -. resources_downloaded/env/bin/activate && vela resources_downloaded/kws_asr/kws_micronet_m.tflite \ --accelerator-config=ethos-u65-256 \ --optimise Performance --config scripts/vela/default_vela.ini \ --memory-mode=Dedicated_Sram \ --system-config=Ethos_U65_High_End \ --output-dir=resources_downloaded/kws_asr -mv resources_downloaded/kws_asr/ds_cnn_clustered_int8_vela.tflite resources_downloaded/kws_asr/ds_cnn_clustered_int8_vela_Y256.tflite +mv resources_downloaded/kws_asr/kws_micronet_m_vela.tflite resources_downloaded/kws_asr/kws_micronet_m_vela_Y256.tflite . resources_downloaded/env/bin/activate && vela resources_downloaded/inference_runner/dnn_s_quantized.tflite \ --accelerator-config=ethos-u55-128 \ diff --git a/docs/use_cases/kws.md b/docs/use_cases/kws.md index 84eddd6..0c50fe5 100644 --- a/docs/use_cases/kws.md +++ b/docs/use_cases/kws.md @@ -23,7 +23,7 @@ Use-case code could be found in the following directory: [source/use_case/kws](.
### Preprocessing and feature extraction -The `DS-CNN` keyword spotting model that is used with the Code Samples expects audio data to be preprocessed in a +The `MicroNet` keyword spotting model that is used with the Code Samples expects audio data to be preprocessed in a specific way before performing an inference. Therefore, this section aims to provide an overview of the feature extraction process used. @@ -62,7 +62,7 @@ used. ### Postprocessing After an inference is complete, the word with the highest detected probability is output to console. Providing that the -probability is larger than a threshold value. The default is set to `0.9`. +probability is larger than a threshold value. The default is set to `0.7`. If multiple inferences are performed for an audio clip, then multiple results are output. @@ -107,7 +107,7 @@ In addition to the already specified build option in the main documentation, the this number, then it is padded with zeros. The default value is `16000`. - `kws_MODEL_SCORE_THRESHOLD`: Threshold value that must be applied to the inference results for a label to be deemed - valid. Goes from 0.00 to 1.0. The default is `0.9`. + valid. Goes from 0.00 to 1.0. The default is `0.7`. - `kws_ACTIVATION_BUF_SZ`: The intermediate, or activation, buffer size reserved for the NN model. By default, it is set to 2MiB and is enough for most models @@ -247,7 +247,7 @@ For further information: [Optimize model with Vela compiler](../sections/buildin To run the application with a custom model, you must provide a `labels_.txt` file of labels that are associated with the model. Each line of the file must correspond to one of the outputs in your model. Refer to the -provided `ds_cnn_labels.txt` file for an example. +provided `micronet_kws_labels.txt` file for an example. Then, you must set `kws_MODEL_TFLITE_PATH` to the location of the Vela processed model file and `kws_LABELS_TXT_FILE`to the location of the associated labels file. 
@@ -369,24 +369,24 @@ What the preceding choices do: INFO - Model INPUT tensors: INFO - tensor type is INT8 INFO - tensor occupies 490 bytes with dimensions - INFO - 0: 1 - INFO - 1: 1 - INFO - 2: 49 - INFO - 3: 10 + INFO - 0: 1 + INFO - 1: 49 + INFO - 2: 10 + INFO - 3: 1 INFO - Quant dimension: 0 - INFO - Scale[0] = 1.107164 - INFO - ZeroPoint[0] = 95 + INFO - Scale[0] = 0.201095 + INFO - ZeroPoint[0] = -5 INFO - Model OUTPUT tensors: - INFO - tensor type is INT8 - INFO - tensor occupies 12 bytes with dimensions - INFO - 0: 1 - INFO - 1: 12 + INFO - tensor type is INT8 + INFO - tensor occupies 12 bytes with dimensions + INFO - 0: 1 + INFO - 1: 12 INFO - Quant dimension: 0 - INFO - Scale[0] = 0.003906 - INFO - ZeroPoint[0] = -128 - INFO - Activation buffer (a.k.a tensor arena) size used: 72848 - INFO - Number of operators: 1 - INFO - Operator 0: ethos-u + INFO - Scale[0] = 0.056054 + INFO - ZeroPoint[0] = -54 + INFO - Activation buffer (a.k.a tensor arena) size used: 127068 + INFO - Number of operators: 1 + INFO - Operator 0: ethos-u ``` 5. List audio clips: Prints a list of pair ... indexes. The original filenames are embedded in the application, like so: @@ -405,18 +405,21 @@ Please select the first menu option to execute inference on the first file. 
The following example illustrates the output for classification: -```logINFO - Running inference on audio clip 0 => down.wav +```log + +INFO - Running inference on audio clip 0 => down.wav INFO - Inference 1/1 INFO - Final results: INFO - Total number of inferences: 1 -INFO - For timestamp: 0.000000 (inference #: 0); label: down, score: 0.996094; threshold: 0.900000 +INFO - For timestamp: 0.000000 (inference #: 0); label: down, score: 0.986182; threshold: 0.700000 INFO - Profile for Inference: -INFO - NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 217385 -INFO - NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 82607 -INFO - NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 59608 -INFO - NPU ACTIVE cycles: 680611 -INFO - NPU IDLE cycles: 561 -INFO - NPU TOTAL cycles: 681172 +INFO - NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 132130 +INFO - NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 48252 +INFO - NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 17544 +INFO - NPU ACTIVE cycles: 413814 +INFO - NPU IDLE cycles: 358 +INFO - NPU TOTAL cycles: 414172 + ``` On most systems running Fast Model, each inference takes under 30 seconds. @@ -425,22 +428,22 @@ The profiling section of the log shows that for this inference: - *Ethos-U* PMU report: - - 681,172 total cycle: The number of NPU cycles. + - 414,172 total cycle: The number of NPU cycles. - - 680,611 active cycles: The number of NPU cycles that were used for computation. + - 413,814 active cycles: The number of NPU cycles that were used for computation. - - 561 idle cycles: The number of cycles for which the NPU was idle. + - 358 idle cycles: The number of cycles for which the NPU was idle. - - 217,385 AXI0 read beats: The number of AXI beats with read transactions from the AXI0 bus. AXI0 is the bus where the + - 132,130 AXI0 read beats: The number of AXI beats with read transactions from the AXI0 bus. AXI0 is the bus where the *Ethos-U* NPU reads and writes to the computation buffers, activation buf, or tensor arenas. 
- - 82,607 write cycles: The number of AXI beats with write transactions to AXI0 bus. + - 48,252 write cycles: The number of AXI beats with write transactions to AXI0 bus. - - 59,608 AXI1 read beats: The number of AXI beats with read transactions from the AXI1 bus. AXI1 is the bus where the + - 17,544 AXI1 read beats: The number of AXI beats with read transactions from the AXI1 bus. AXI1 is the bus where the *Ethos-U* NPU reads the model. So, read-only. - For FPGA platforms, a CPU cycle count can also be enabled. However, do not use cycle counters for FVP, as the CPU model is not cycle-approximate or cycle-accurate. -> **Note:** The application prints the highest confidence score and the associated label from the `ds_cnn_labels.txt` +> **Note:** The application prints the highest confidence score and the associated label from the `micronet_kws_labels.txt` > file. diff --git a/docs/use_cases/kws_asr.md b/docs/use_cases/kws_asr.md index 42c9d3a..22f1e9d 100644 --- a/docs/use_cases/kws_asr.md +++ b/docs/use_cases/kws_asr.md @@ -44,7 +44,7 @@ By default, the KWS model is run purely on the CPU and **not** on the *Ethos-U55 #### Keyword Spotting Preprocessing -The `DS-CNN` keyword spotting model that is used with the Code Samples expects audio data to be preprocessed in a +The `MicroNet` keyword spotting model that is used with the Code Samples expects audio data to be preprocessed in a specific way before performing an inference. Therefore, this section aims to provide an overview of the feature extraction process used. @@ -455,43 +455,30 @@ What the preceding choices do: 4. 
Show NN model info: Prints information about the model data type, input, and output, tensor sizes: - ```log + ```log + INFO - Model info: INFO - Model INPUT tensors: INFO - tensor type is INT8 INFO - tensor occupies 490 bytes with dimensions - INFO - 0: 1 - INFO - 1: 1 - INFO - 2: 49 - INFO - 3: 10 + INFO - 0: 1 + INFO - 1: 49 + INFO - 2: 10 + INFO - 3: 1 INFO - Quant dimension: 0 - INFO - Scale[0] = 1.107164 - INFO - ZeroPoint[0] = 95 + INFO - Scale[0] = 0.201095 + INFO - ZeroPoint[0] = -5 INFO - Model OUTPUT tensors: - INFO - tensor type is INT8 - INFO - tensor occupies 12 bytes with dimensions - INFO - 0: 1 - INFO - 1: 12 + INFO - tensor type is INT8 + INFO - tensor occupies 12 bytes with dimensions + INFO - 0: 1 + INFO - 1: 12 INFO - Quant dimension: 0 - INFO - Scale[0] = 0.003906 - INFO - ZeroPoint[0] = -128 - INFO - Activation buffer (a.k.a tensor arena) size used: 123616 - INFO - Number of operators: 16 - INFO - Operator 0: RESHAPE - INFO - Operator 1: CONV_2D - INFO - Operator 2: DEPTHWISE_CONV_2D - INFO - Operator 3: CONV_2D - INFO - Operator 4: DEPTHWISE_CONV_2D - INFO - Operator 5: CONV_2D - INFO - Operator 6: DEPTHWISE_CONV_2D - INFO - Operator 7: CONV_2D - INFO - Operator 8: DEPTHWISE_CONV_2D - INFO - Operator 9: CONV_2D - INFO - Operator 10: DEPTHWISE_CONV_2D - INFO - Operator 11: CONV_2D - INFO - Operator 12: AVERAGE_POOL_2D - INFO - Operator 13: RESHAPE - INFO - Operator 14: FULLY_CONNECTED - INFO - Operator 15: SOFTMAX + INFO - Scale[0] = 0.056054 + INFO - ZeroPoint[0] = -54 + INFO - Activation buffer (a.k.a tensor arena) size used: 127068 + INFO - Number of operators: 1 + INFO - Operator 0: ethos-u + INFO - Model INPUT tensors: INFO - tensor type is INT8 INFO - tensor occupies 11544 bytes with dimensions @@ -511,9 +498,9 @@ What the preceding choices do: INFO - Quant dimension: 0 INFO - Scale[0] = 0.003906 INFO - ZeroPoint[0] = -128 - INFO - Activation buffer (a.k.a tensor arena) size used: 809808 + INFO - Activation buffer (a.k.a tensor arena) 
size used: 4184332 INFO - Number of operators: 1 - INFO - Operator 0: ethos-u + INFO - Operator 0: ethos-u ``` 5. List audio clips: Prints a list of pair ... indexes. The original filenames are embedded in the application, like so: @@ -534,27 +521,31 @@ INFO - KWS audio data window size 16000 INFO - Running KWS inference on audio clip 0 => yes_no_go_stop.wav INFO - Inference 1/7 INFO - For timestamp: 0.000000 (inference #: 0); threshold: 0.900000 -INFO - label @ 0: yes, score: 0.996094 +INFO - label @ 0: yes, score: 0.997407 INFO - Profile for Inference: -INFO - NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 217385 -INFO - NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 82607 -INFO - NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 59608 -INFO - NPU ACTIVE cycles: 680611 -INFO - NPU IDLE cycles: 561 -INFO - NPU TOTAL cycles: 681172 +INFO - NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 132130 +INFO - NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 48252 +INFO - NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 17544 +INFO - NPU ACTIVE cycles: 413814 +INFO - NPU IDLE cycles: 358 +INFO - NPU TOTAL cycles: 414172 INFO - Keyword spotted INFO - Inference 1/2 INFO - Inference 2/2 -INFO - Result for inf 0: no gow -INFO - Result for inf 1: stoppe -INFO - Final result: no gow stoppe +INFO - Result for inf 0: no go +INFO - Result for inf 1: stop +INFO - Final result: no go stop INFO - Profile for Inference: -INFO - NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 13520864 -INFO - NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2841970 -INFO - NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 2717670 -INFO - NPU ACTIVE cycles: 28909309 -INFO - NPU IDLE cycles: 863 -INFO - NPU TOTAL cycles: 28910172 +INFO - NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 8895431 +INFO - NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 1890168 +INFO - NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 1740069 +INFO - NPU ACTIVE cycles: 30164330 +INFO - NPU IDLE cycles: 342 +INFO - NPU TOTAL cycles: 30164672 +INFO - Main loop terminated. +INFO - program terminating... 
+INFO - releasing platform Arm Corstone-300 (SSE-300) + ``` It can take several minutes to complete one inference run. The average time is around 2-3 minutes. @@ -567,22 +558,21 @@ The profiling section of the log shows that for the ASR inference: - *Ethos-U* PMU report: - - 28,910,172 total cycle: The number of NPU cycles. + - 30,164,672 total cycle: The number of NPU cycles. - - 28,909,309 active cycles: The number of NPU cycles that were used for computation. + - 30,164,330 active cycles: The number of NPU cycles that were used for computation. - - 863 idle cycles: The number of cycles for which the NPU was idle. + - 342 idle cycles: The number of cycles for which the NPU was idle. - - 13,520,864 AXI0 read beats: The number of AXI beats with read transactions from the AXI0 bus. AXI0 is the bus where + - 8,895,431 AXI0 read beats: The number of AXI beats with read transactions from the AXI0 bus. AXI0 is the bus where the *Ethos-U* NPU reads and writes to the computation buffers, activation buf, or tensor arenas. - - 2,841,970 AXI0 write beats: The number of AXI beats with write transactions to AXI0 bus. + - 1,890,168 AXI0 write beats: The number of AXI beats with write transactions to AXI0 bus. - - 2,717,670 AXI1 read beats: The number of AXI beats with read transactions from the AXI1 bus. AXI1 is the bus where + - 1,740,069 AXI1 read beats: The number of AXI beats with read transactions from the AXI1 bus. AXI1 is the bus where the *Ethos-U55* NPU reads the model. So, read-only. - For FPGA platforms, a CPU cycle count can also be enabled. However, do not use cycle counters for FVP, as the CPU model is not cycle-approximate or cycle-accurate. -> **Note:** In this example, the KWS inference does *not* use the *Ethos-U55* and only runs on the CPU. Therefore, `0` -> Active NPU cycles are shown. 
+ diff --git a/resources/kws/labels/ds_cnn_labels.txt b/resources/kws/labels/ds_cnn_labels.txt deleted file mode 100644 index ba41645..0000000 --- a/resources/kws/labels/ds_cnn_labels.txt +++ /dev/null @@ -1,12 +0,0 @@ -_silence_ -_unknown_ -yes -no -up -down -left -right -on -off -stop -go \ No newline at end of file diff --git a/resources/kws/labels/micronet_kws_labels.txt b/resources/kws/labels/micronet_kws_labels.txt new file mode 100644 index 0000000..7ad7488 --- /dev/null +++ b/resources/kws/labels/micronet_kws_labels.txt @@ -0,0 +1,12 @@ +down +go +left +no +off +on +right +stop +up +yes +_silence_ +_unknown_ diff --git a/resources/kws_asr/labels/ds_cnn_labels.txt b/resources/kws_asr/labels/ds_cnn_labels.txt deleted file mode 100644 index ba41645..0000000 --- a/resources/kws_asr/labels/ds_cnn_labels.txt +++ /dev/null @@ -1,12 +0,0 @@ -_silence_ -_unknown_ -yes -no -up -down -left -right -on -off -stop -go \ No newline at end of file diff --git a/resources/kws_asr/labels/micronet_kws_labels.txt b/resources/kws_asr/labels/micronet_kws_labels.txt new file mode 100644 index 0000000..7ad7488 --- /dev/null +++ b/resources/kws_asr/labels/micronet_kws_labels.txt @@ -0,0 +1,12 @@ +down +go +left +no +off +on +right +stop +up +yes +_silence_ +_unknown_ diff --git a/set_up_default_resources.py b/set_up_default_resources.py index 91007e4..d244213 100755 --- a/set_up_default_resources.py +++ b/set_up_default_resources.py @@ -56,12 +56,12 @@ json_uc_res = [{ }, { "use_case_name": "kws", - "resources": [{"name": "ds_cnn_clustered_int8.tflite", - "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/ds_cnn_clustered_int8.tflite"}, - {"name": "ifm0.npy", - "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_input/input_2/0.npy"}, + "resources": [{"name": "ifm0.npy", + "url": 
"https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_input/input/0.npy"}, {"name": "ofm0.npy", - "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_output/Identity/0.npy"}] + "url": "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_output/Identity/0.npy"}, + {"name": "kws_micronet_m.tflite", + "url": "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/kws_micronet_m.tflite"}] }, { "use_case_name": "vww", @@ -80,12 +80,12 @@ json_uc_res = [{ "url": "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_input/input_2_int8/0.npy"}, {"sub_folder": "asr", "name": "ofm0.npy", "url": "https://github.com/ARM-software/ML-zoo/raw/1a92aa08c0de49a7304e0a7f3f59df6f4fd33ac8/models/speech_recognition/wav2letter/tflite_pruned_int8/testing_output/Identity_int8/0.npy"}, - {"name": "ds_cnn_clustered_int8.tflite", - "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/ds_cnn_clustered_int8.tflite"}, {"sub_folder": "kws", "name": "ifm0.npy", - "url": "https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_input/input_2/0.npy"}, + "url": "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_input/input/0.npy"}, {"sub_folder": "kws", "name": "ofm0.npy", - "url": 
"https://github.com/ARM-software/ML-zoo/raw/68b5fbc77ed28e67b2efc915997ea4477c1d9d5b/models/keyword_spotting/ds_cnn_large/tflite_clustered_int8/testing_output/Identity/0.npy"}] + "url": "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/testing_output/Identity/0.npy"}, + {"name": "kws_micronet_m.tflite", + "url": "https://github.com/ARM-software/ML-zoo/raw/9f506fe52b39df545f0e6c5ff9223f671bc5ae00/models/keyword_spotting/micronet_medium/tflite_int8/kws_micronet_m.tflite"}] }, { "use_case_name": "noise_reduction", @@ -303,8 +303,8 @@ def set_up_resources(run_vela_on_models: bool = False, # 3. Run vela on models in resources_downloaded # New models will have same name with '_vela' appended. # For example: - # original model: ds_cnn_clustered_int8.tflite - # after vela model: ds_cnn_clustered_int8_vela_H128.tflite + # original model: kws_micronet_m.tflite + # after vela model: kws_micronet_m_vela_H128.tflite # # Note: To avoid to run vela twice on the same model, it's supposed that # downloaded model names don't contain the 'vela' word. diff --git a/source/application/main/Classifier.cc b/source/application/main/Classifier.cc index c5519fb..a6ff532 100644 --- a/source/application/main/Classifier.cc +++ b/source/application/main/Classifier.cc @@ -24,61 +24,40 @@ #include #include #include +#include "PlatformMath.hpp" namespace arm { namespace app { - template - void SetVectorResults(std::set>& topNSet, + void Classifier::SetVectorResults(std::set>& topNSet, std::vector& vecResults, - TfLiteTensor* tensor, - const std::vector & labels) { - - /* For getting the floating point values, we need quantization parameters. */ - QuantParams quantParams = GetTensorQuantParams(tensor); + const std::vector & labels) + { /* Reset the iterator to the largest element - use reverse iterator. 
*/ - auto topNIter = topNSet.rbegin(); - for (size_t i = 0; i < vecResults.size() && topNIter != topNSet.rend(); ++i, ++topNIter) { - T score = topNIter->first; - vecResults[i].m_normalisedVal = quantParams.scale * (score - quantParams.offset); - vecResults[i].m_label = labels[topNIter->second]; - vecResults[i].m_labelIdx = topNIter->second; - } - } - - template<> - void SetVectorResults(std::set>& topNSet, - std::vector& vecResults, - TfLiteTensor* tensor, - const std::vector & labels) { - UNUSED(tensor); - /* Reset the iterator to the largest element - use reverse iterator. */ auto topNIter = topNSet.rbegin(); for (size_t i = 0; i < vecResults.size() && topNIter != topNSet.rend(); ++i, ++topNIter) { vecResults[i].m_normalisedVal = topNIter->first; vecResults[i].m_label = labels[topNIter->second]; vecResults[i].m_labelIdx = topNIter->second; } - } - template - bool Classifier::GetTopNResults(TfLiteTensor* tensor, + bool Classifier::GetTopNResults(const std::vector& tensor, std::vector& vecResults, uint32_t topNCount, const std::vector & labels) { - std::set> sortedSet; + + std::set> sortedSet; /* NOTE: inputVec's size verification against labels should be * checked by the calling/public function. */ - T* tensorData = tflite::GetTensorData(tensor); /* Set initial elements. */ for (uint32_t i = 0; i < topNCount; ++i) { - sortedSet.insert({tensorData[i], i}); + sortedSet.insert({tensor[i], i}); } /* Initialise iterator. */ @@ -86,33 +65,26 @@ namespace app { /* Scan through the rest of elements with compare operations. */ for (uint32_t i = topNCount; i < labels.size(); ++i) { - if (setFwdIter->first < tensorData[i]) { + if (setFwdIter->first < tensor[i]) { sortedSet.erase(*setFwdIter); - sortedSet.insert({tensorData[i], i}); + sortedSet.insert({tensor[i], i}); setFwdIter = sortedSet.begin(); } } /* Final results' container. 
*/ vecResults = std::vector(topNCount); - - SetVectorResults(sortedSet, vecResults, tensor, labels); + SetVectorResults(sortedSet, vecResults, labels); return true; } - template bool Classifier::GetTopNResults(TfLiteTensor* tensor, - std::vector& vecResults, - uint32_t topNCount, const std::vector & labels); - - template bool Classifier::GetTopNResults(TfLiteTensor* tensor, - std::vector& vecResults, - uint32_t topNCount, const std::vector & labels); - bool Classifier::GetClassificationResults( TfLiteTensor* outputTensor, std::vector& vecResults, - const std::vector & labels, uint32_t topNCount) + const std::vector & labels, + uint32_t topNCount, + bool useSoftmax) { if (outputTensor == nullptr) { printf_err("Output vector is null pointer.\n"); @@ -120,7 +92,7 @@ namespace app { } uint32_t totalOutputSize = 1; - for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++){ + for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++) { totalOutputSize *= outputTensor->dims->data[inputDim]; } @@ -139,22 +111,52 @@ namespace app { bool resultState; vecResults.clear(); - /* Get the top N results. */ + /* De-Quantize Output Tensor */ + QuantParams quantParams = GetTensorQuantParams(outputTensor); + + /* Floating point tensor data to be populated + * NOTE: The assumption here is that the output tensor size isn't too + * big and therefore, there's negligible impact on heap usage. 
*/ + std::vector tensorData(totalOutputSize); + + /* Populate the floating point buffer */ switch (outputTensor->type) { - case kTfLiteUInt8: - resultState = GetTopNResults(outputTensor, vecResults, topNCount, labels); + case kTfLiteUInt8: { + uint8_t *tensor_buffer = tflite::GetTensorData(outputTensor); + for (size_t i = 0; i < totalOutputSize; ++i) { + tensorData[i] = quantParams.scale * + (static_cast(tensor_buffer[i]) - quantParams.offset); + } break; - case kTfLiteInt8: - resultState = GetTopNResults(outputTensor, vecResults, topNCount, labels); + } + case kTfLiteInt8: { + int8_t *tensor_buffer = tflite::GetTensorData(outputTensor); + for (size_t i = 0; i < totalOutputSize; ++i) { + tensorData[i] = quantParams.scale * + (static_cast(tensor_buffer[i]) - quantParams.offset); + } break; - case kTfLiteFloat32: - resultState = GetTopNResults(outputTensor, vecResults, topNCount, labels); + } + case kTfLiteFloat32: { + float *tensor_buffer = tflite::GetTensorData(outputTensor); + for (size_t i = 0; i < totalOutputSize; ++i) { + tensorData[i] = tensor_buffer[i]; + } break; + } default: - printf_err("Tensor type %s not supported by classifier\n", TfLiteTypeGetName(outputTensor->type)); + printf_err("Tensor type %s not supported by classifier\n", + TfLiteTypeGetName(outputTensor->type)); return false; } + if (useSoftmax) { + math::MathUtils::SoftmaxF32(tensorData); + } + + /* Get the top N results. */ + resultState = GetTopNResults(tensorData, vecResults, topNCount, labels); + if (!resultState) { printf_err("Failed to get top N results set\n"); return false; @@ -162,6 +164,5 @@ namespace app { return true; } - } /* namespace app */ } /* namespace arm */ \ No newline at end of file diff --git a/source/application/main/PlatformMath.cc b/source/application/main/PlatformMath.cc index 0b8882a..26b4b72 100644 --- a/source/application/main/PlatformMath.cc +++ b/source/application/main/PlatformMath.cc @@ -15,6 +15,8 @@ * limitations under the License. 
*/ #include "PlatformMath.hpp" +#include +#include #if 0 == ARM_DSP_AVAILABLE #include @@ -290,6 +292,24 @@ namespace math { return true; } + void MathUtils::SoftmaxF32(std::vector& vec) + { + /* Fix for numerical stability and apply exp. */ + auto start = vec.begin(); + auto end = vec.end(); + + float maxValue = *std::max_element(start, end); + for (auto it = start; it != end; ++it) { + *it = std::exp((*it) - maxValue); + } + + float sumExp = std::accumulate(start, end, 0.0f); + + for (auto it = start; it != end; ++it) { + *it = (*it)/sumExp; + } + } + } /* namespace math */ } /* namespace app */ } /* namespace arm */ diff --git a/source/application/main/include/Classifier.hpp b/source/application/main/include/Classifier.hpp index 3ee3148..d899e8e 100644 --- a/source/application/main/include/Classifier.hpp +++ b/source/application/main/include/Classifier.hpp @@ -42,18 +42,33 @@ namespace app { * populated by this function. * @param[in] labels Labels vector to match classified classes. * @param[in] topNCount Number of top classifications to pick. Default is 1. + * @param[in] useSoftmax Whether Softmax normalisation should be applied to output. Default is false. * @return true if successful, false otherwise. **/ + virtual bool GetClassificationResults( TfLiteTensor* outputTensor, std::vector& vecResults, - const std::vector & labels, uint32_t topNCount); + const std::vector & labels, uint32_t topNCount, + bool use_softmax = false); + + /** + * @brief Populate the elements of the Classification Result object. + * @param[in] topNSet Ordered set of top 5 output class scores and labels. + * @param[out] vecResults A vector of classification results. + * populated by this function. + * @param[in] labels Labels vector to match classified classes. + **/ + + void SetVectorResults( + std::set>& topNSet, + std::vector& vecResults, + const std::vector & labels); private: /** * @brief Utility function that gets the top N classification results from the * output vector. 
- * @tparam T value type * @param[in] tensor Inference output tensor from an NN model. * @param[out] vecResults A vector of classification results * populated by this function. @@ -61,8 +76,8 @@ namespace app { * @param[in] labels Labels vector to match classified classes. * @return true if successful, false otherwise. **/ - template - bool GetTopNResults(TfLiteTensor* tensor, + + bool GetTopNResults(const std::vector& tensor, std::vector& vecResults, uint32_t topNCount, const std::vector & labels); diff --git a/source/application/main/include/PlatformMath.hpp b/source/application/main/include/PlatformMath.hpp index 6804025..fdb51b2 100644 --- a/source/application/main/include/PlatformMath.hpp +++ b/source/application/main/include/PlatformMath.hpp @@ -161,7 +161,14 @@ namespace math { float* ptrDst, const uint32_t dstLen); + /** + * @brief Scales output scores for an arbitrary number of classes so + * that they sum to 1, allowing output to be expressed as a probability. + * @param[in] vector Vector of floats modified in-place + */ + static void SoftmaxF32(std::vector& vec); }; + } /* namespace math */ } /* namespace app */ } /* namespace arm */ diff --git a/source/use_case/asr/include/AsrClassifier.hpp b/source/use_case/asr/include/AsrClassifier.hpp index 2c97a39..67a200e 100644 --- a/source/use_case/asr/include/AsrClassifier.hpp +++ b/source/use_case/asr/include/AsrClassifier.hpp @@ -32,12 +32,13 @@ namespace app { * populated by this function. * @param[in] labels Labels vector to match classified classes * @param[in] topNCount Number of top classifications to pick. + * @param[in] use_softmax Whether softmax scaling should be applied to model output. * @return true if successful, false otherwise. 
**/ bool GetClassificationResults( TfLiteTensor* outputTensor, std::vector& vecResults, - const std::vector & labels, uint32_t topNCount) override; + const std::vector & labels, uint32_t topNCount, bool use_softmax = false) override; private: /** diff --git a/source/use_case/asr/src/AsrClassifier.cc b/source/use_case/asr/src/AsrClassifier.cc index c18bd88..a715068 100644 --- a/source/use_case/asr/src/AsrClassifier.cc +++ b/source/use_case/asr/src/AsrClassifier.cc @@ -73,8 +73,9 @@ template bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tenso bool arm::app::AsrClassifier::GetClassificationResults( TfLiteTensor* outputTensor, std::vector& vecResults, - const std::vector & labels, uint32_t topNCount) + const std::vector & labels, uint32_t topNCount, bool use_softmax) { + UNUSED(use_softmax); vecResults.clear(); constexpr int minTensorDims = static_cast( diff --git a/source/use_case/kws/include/DsCnnMfcc.hpp b/source/use_case/kws/include/DsCnnMfcc.hpp deleted file mode 100644 index 3f681af..0000000 --- a/source/use_case/kws/include/DsCnnMfcc.hpp +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_DSCNN_MFCC_HPP -#define KWS_DSCNN_MFCC_HPP - -#include "Mfcc.hpp" - -namespace arm { -namespace app { -namespace audio { - - /* Class to provide DS-CNN specific MFCC calculation requirements. 
*/ - class DsCnnMFCC : public MFCC { - - public: - static constexpr uint32_t ms_defaultSamplingFreq = 16000; - static constexpr uint32_t ms_defaultNumFbankBins = 40; - static constexpr uint32_t ms_defaultMelLoFreq = 20; - static constexpr uint32_t ms_defaultMelHiFreq = 4000; - static constexpr bool ms_defaultUseHtkMethod = true; - - explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen) - : MFCC(MfccParams( - ms_defaultSamplingFreq, ms_defaultNumFbankBins, - ms_defaultMelLoFreq, ms_defaultMelHiFreq, - numFeats, frameLen, ms_defaultUseHtkMethod)) - {} - DsCnnMFCC() = delete; - ~DsCnnMFCC() = default; - }; - -} /* namespace audio */ -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_DSCNN_MFCC_HPP */ \ No newline at end of file diff --git a/source/use_case/kws/include/DsCnnModel.hpp b/source/use_case/kws/include/DsCnnModel.hpp deleted file mode 100644 index a1a45cd..0000000 --- a/source/use_case/kws/include/DsCnnModel.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef KWS_DSCNNMODEL_HPP -#define KWS_DSCNNMODEL_HPP - -#include "Model.hpp" - -extern const int g_FrameLength; -extern const int g_FrameStride; -extern const float g_ScoreThreshold; - -namespace arm { -namespace app { - - class DsCnnModel : public Model { - public: - /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 2; - static constexpr uint32_t ms_inputColsIdx = 3; - static constexpr uint32_t ms_outputRowsIdx = 2; - static constexpr uint32_t ms_outputColsIdx = 3; - - protected: - /** @brief Gets the reference to op resolver interface class. */ - const tflite::MicroOpResolver& GetOpResolver() override; - - /** @brief Adds operations to the op resolver instance. */ - bool EnlistOperations() override; - - const uint8_t* ModelPointer() override; - - size_t ModelSize() override; - - private: - /* Maximum number of individual operations that can be enlisted. */ - static constexpr int ms_maxOpCnt = 8; - - /* A mutable op resolver instance. */ - tflite::MicroMutableOpResolver m_opResolver; - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_DSCNNMODEL_HPP */ diff --git a/source/use_case/kws/include/MicroNetKwsMfcc.hpp b/source/use_case/kws/include/MicroNetKwsMfcc.hpp new file mode 100644 index 0000000..b2565a3 --- /dev/null +++ b/source/use_case/kws/include/MicroNetKwsMfcc.hpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_MICRONET_MFCC_HPP +#define KWS_MICRONET_MFCC_HPP + +#include "Mfcc.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide MicroNet specific MFCC calculation requirements. */ + class MicroNetKwsMFCC : public MFCC { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 40; + static constexpr uint32_t ms_defaultMelLoFreq = 20; + static constexpr uint32_t ms_defaultMelHiFreq = 4000; + static constexpr bool ms_defaultUseHtkMethod = true; + + explicit MicroNetKwsMFCC(const size_t numFeats, const size_t frameLen) + : MFCC(MfccParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + numFeats, frameLen, ms_defaultUseHtkMethod)) + {} + MicroNetKwsMFCC() = delete; + ~MicroNetKwsMFCC() = default; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_MICRONET_MFCC_HPP */ \ No newline at end of file diff --git a/source/use_case/kws/include/MicroNetKwsModel.hpp b/source/use_case/kws/include/MicroNetKwsModel.hpp new file mode 100644 index 0000000..3259c45 --- /dev/null +++ b/source/use_case/kws/include/MicroNetKwsModel.hpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_MICRONETMODEL_HPP +#define KWS_MICRONETMODEL_HPP + +#include "Model.hpp" + +extern const int g_FrameLength; +extern const int g_FrameStride; +extern const float g_ScoreThreshold; + +namespace arm { +namespace app { + + class MicroNetKwsModel : public Model { + public: + /* Indices for the expected model - based on input and output tensor shapes */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_outputRowsIdx = 2; + static constexpr uint32_t ms_outputColsIdx = 3; + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. */ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int ms_maxOpCnt = 7; + + /* A mutable op resolver instance. */ + tflite::MicroMutableOpResolver m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_MICRONETMODEL_HPP */ diff --git a/source/use_case/kws/src/DsCnnModel.cc b/source/use_case/kws/src/DsCnnModel.cc deleted file mode 100644 index 4edfc04..0000000 --- a/source/use_case/kws/src/DsCnnModel.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "DsCnnModel.hpp" - -#include "hal.h" - -const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver() -{ - return this->m_opResolver; -} - -bool arm::app::DsCnnModel::EnlistOperations() -{ - this->m_opResolver.AddReshape(); - this->m_opResolver.AddAveragePool2D(); - this->m_opResolver.AddConv2D(); - this->m_opResolver.AddDepthwiseConv2D(); - this->m_opResolver.AddFullyConnected(); - this->m_opResolver.AddRelu(); - this->m_opResolver.AddSoftmax(); - -#if defined(ARM_NPU) - if (kTfLiteOk == this->m_opResolver.AddEthosU()) { - info("Added %s support to op resolver\n", - tflite::GetString_ETHOSU()); - } else { - printf_err("Failed to add Arm NPU support to op resolver."); - return false; - } -#endif /* ARM_NPU */ - return true; -} - -extern uint8_t* GetModelPointer(); -const uint8_t* arm::app::DsCnnModel::ModelPointer() -{ - return GetModelPointer(); -} - -extern size_t GetModelLen(); -size_t arm::app::DsCnnModel::ModelSize() -{ - return GetModelLen(); -} \ No newline at end of file diff --git a/source/use_case/kws/src/MainLoop.cc b/source/use_case/kws/src/MainLoop.cc index c683e71..bde246b 100644 --- a/source/use_case/kws/src/MainLoop.cc +++ b/source/use_case/kws/src/MainLoop.cc @@ -16,7 +16,7 @@ */ #include "InputFiles.hpp" /* For input audio clips. */ #include "Classifier.hpp" /* Classifier. */ -#include "DsCnnModel.hpp" /* Model class for running inference. */ +#include "MicroNetKwsModel.hpp" /* Model class for running inference. */ #include "hal.h" /* Brings in platform definitions. 
*/ #include "Labels.hpp" /* For label strings. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ @@ -49,7 +49,7 @@ static void DisplayMenu() void main_loop(hal_platform& platform) { - arm::app::DsCnnModel model; /* Model wrapper object. */ + arm::app::MicroNetKwsModel model; /* Model wrapper object. */ /* Load the model. */ if (!model.Init()) { diff --git a/source/use_case/kws/src/MicroNetKwsModel.cc b/source/use_case/kws/src/MicroNetKwsModel.cc new file mode 100644 index 0000000..48a9b8c --- /dev/null +++ b/source/use_case/kws/src/MicroNetKwsModel.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "MicroNetKwsModel.hpp" + +#include "hal.h" + +const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver() +{ + return this->m_opResolver; +} + +bool arm::app::MicroNetKwsModel::EnlistOperations() +{ + this->m_opResolver.AddReshape(); + this->m_opResolver.AddAveragePool2D(); + this->m_opResolver.AddConv2D(); + this->m_opResolver.AddDepthwiseConv2D(); + this->m_opResolver.AddFullyConnected(); + this->m_opResolver.AddRelu(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +extern uint8_t* GetModelPointer(); +const uint8_t* arm::app::MicroNetKwsModel::ModelPointer() +{ + return GetModelPointer(); +} + +extern size_t GetModelLen(); +size_t arm::app::MicroNetKwsModel::ModelSize() +{ + return GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/kws/src/UseCaseHandler.cc b/source/use_case/kws/src/UseCaseHandler.cc index 3d95753..8085af7 100644 --- a/source/use_case/kws/src/UseCaseHandler.cc +++ b/source/use_case/kws/src/UseCaseHandler.cc @@ -18,9 +18,9 @@ #include "InputFiles.hpp" #include "Classifier.hpp" -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsMfcc.hpp" #include "AudioUtils.hpp" #include "UseCaseCommonUtils.hpp" #include "KwsResult.hpp" @@ -59,7 +59,7 @@ namespace app { * @return Function to be called providing audio sample and sliding window index. 
*/ static std::function&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize); @@ -72,8 +72,8 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartX = 20; constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast( - (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? - arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? + arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); auto& model = ctx.Get("model"); @@ -105,10 +105,10 @@ namespace app { } TfLiteIntArray* inputShape = model.GetInputShape(0); - const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx]; - const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx]; + const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx]; + const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx]; - audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength); + audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength); mfcc.Init(); /* Deduce the data length required for 1 inference from the network parameters. */ @@ -132,7 +132,7 @@ namespace app { /* We expect to be sampling 1 second worth of data at a time. * NOTE: This is only used for time stamp calculation. 
*/ - const float secondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; do { platform.data_psn->clear(COLOR_BLACK); @@ -208,7 +208,7 @@ namespace app { std::vector classificationResult; auto& classifier = ctx.Get("classifier"); classifier.GetClassificationResults(outputTensor, classificationResult, - ctx.Get&>("labels"), 1); + ctx.Get&>("labels"), 1, true); results.emplace_back(kws::KwsResult(classificationResult, audioDataSlider.Index() * secondsPerSample * audioDataStride, @@ -240,7 +240,6 @@ namespace app { return true; } - static bool PresentInferenceResult(hal_platform& platform, const std::vector& results) { @@ -259,7 +258,6 @@ namespace app { std::string topKeyword{""}; float score = 0.f; - if (!results[i].m_resultVec.empty()) { topKeyword = results[i].m_resultVec[0].m_label; score = results[i].m_resultVec[0].m_normalisedVal; @@ -366,7 +364,7 @@ namespace app { static std::function&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function&, size_t, bool, size_t)> mfccFeatureCalc; diff --git a/source/use_case/kws/usecase.cmake b/source/use_case/kws/usecase.cmake index 34e39e4..9f3736e 100644 --- a/source/use_case/kws/usecase.cmake +++ b/source/use_case/kws/usecase.cmake @@ -20,7 +20,7 @@ USER_OPTION(${use_case}_FILE_PATH "Directory with custom WAV input files, or pat PATH_OR_FILE) USER_OPTION(${use_case}_LABELS_TXT_FILE "Labels' txt file for the chosen model." - ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/ds_cnn_labels.txt + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/micronet_kws_labels.txt FILEPATH) USER_OPTION(${use_case}_AUDIO_RATE "Specify the target sampling rate. Default is 16000." 
@@ -48,7 +48,7 @@ USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples STRING) USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD "Specify the score threshold [0.0, 1.0) that must be applied to the inference results for a label to be deemed valid." - 0.9 + 0.7 STRING) # Generate input files @@ -73,10 +73,11 @@ USER_OPTION(${use_case}_ACTIVATION_BUF_SZ "Activation buffer size for the chosen 0x00100000 STRING) + if (ETHOS_U_NPU_ENABLED) - set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) + set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/kws_micronet_m_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) else() - set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8.tflite) + set(DEFAULT_MODEL_PATH ${DEFAULT_MODEL_DIR}/kws_micronet_m.tflite) endif() set(EXTRA_MODEL_CODE diff --git a/source/use_case/kws_asr/include/AsrClassifier.hpp b/source/use_case/kws_asr/include/AsrClassifier.hpp index 7dbb6e9..6ab9685 100644 --- a/source/use_case/kws_asr/include/AsrClassifier.hpp +++ b/source/use_case/kws_asr/include/AsrClassifier.hpp @@ -32,12 +32,14 @@ namespace app { * populated by this function. * @param[in] labels Labels vector to match classified classes * @param[in] topNCount Number of top classifications to pick. + * @param[in] use_softmax Whether softmax scaling should be applied to model output. * @return true if successful, false otherwise. **/ bool GetClassificationResults( TfLiteTensor* outputTensor, std::vector& vecResults, - const std::vector & labels, uint32_t topNCount) override; + const std::vector & labels, uint32_t topNCount, + bool use_softmax = false) override; private: diff --git a/source/use_case/kws_asr/include/DsCnnMfcc.hpp b/source/use_case/kws_asr/include/DsCnnMfcc.hpp deleted file mode 100644 index c97dd9d..0000000 --- a/source/use_case/kws_asr/include/DsCnnMfcc.hpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_DSCNN_MFCC_HPP -#define KWS_ASR_DSCNN_MFCC_HPP - -#include "Mfcc.hpp" - -namespace arm { -namespace app { -namespace audio { - - /* Class to provide DS-CNN specific MFCC calculation requirements. */ - class DsCnnMFCC : public MFCC { - - public: - static constexpr uint32_t ms_defaultSamplingFreq = 16000; - static constexpr uint32_t ms_defaultNumFbankBins = 40; - static constexpr uint32_t ms_defaultMelLoFreq = 20; - static constexpr uint32_t ms_defaultMelHiFreq = 4000; - static constexpr bool ms_defaultUseHtkMethod = true; - - - explicit DsCnnMFCC(const size_t numFeats, const size_t frameLen) - : MFCC(MfccParams( - ms_defaultSamplingFreq, ms_defaultNumFbankBins, - ms_defaultMelLoFreq, ms_defaultMelHiFreq, - numFeats, frameLen, ms_defaultUseHtkMethod)) - {} - DsCnnMFCC() = delete; - ~DsCnnMFCC() = default; - }; - -} /* namespace audio */ -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_ASR_DSCNN_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/DsCnnModel.hpp b/source/use_case/kws_asr/include/DsCnnModel.hpp deleted file mode 100644 index 92d96b9..0000000 --- a/source/use_case/kws_asr/include/DsCnnModel.hpp +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef KWS_ASR_DSCNNMODEL_HPP -#define KWS_ASR_DSCNNMODEL_HPP - -#include "Model.hpp" - -namespace arm { -namespace app { -namespace kws { - extern const int g_FrameLength; - extern const int g_FrameStride; - extern const float g_ScoreThreshold; - extern const uint32_t g_NumMfcc; - extern const uint32_t g_NumAudioWins; -} /* namespace kws */ -} /* namespace app */ -} /* namespace arm */ - -namespace arm { -namespace app { - - class DsCnnModel : public Model { - public: - /* Indices for the expected model - based on input and output tensor shapes */ - static constexpr uint32_t ms_inputRowsIdx = 2; - static constexpr uint32_t ms_inputColsIdx = 3; - static constexpr uint32_t ms_outputRowsIdx = 2; - static constexpr uint32_t ms_outputColsIdx = 3; - - protected: - /** @brief Gets the reference to op resolver interface class. */ - const tflite::MicroOpResolver& GetOpResolver() override; - - /** @brief Adds operations to the op resolver instance. */ - bool EnlistOperations() override; - - const uint8_t* ModelPointer() override; - - size_t ModelSize() override; - - private: - /* Maximum number of individual operations that can be enlisted. */ - static constexpr int ms_maxOpCnt = 10; - - /* A mutable op resolver instance. 
*/ - tflite::MicroMutableOpResolver m_opResolver; - }; - -} /* namespace app */ -} /* namespace arm */ - -#endif /* KWS_DSCNNMODEL_HPP */ diff --git a/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp new file mode 100644 index 0000000..43bd390 --- /dev/null +++ b/source/use_case/kws_asr/include/MicroNetKwsMfcc.hpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KWS_ASR_MICRONET_MFCC_HPP +#define KWS_ASR_MICRONET_MFCC_HPP + +#include "Mfcc.hpp" + +namespace arm { +namespace app { +namespace audio { + + /* Class to provide MicroNet specific MFCC calculation requirements. 
*/ + class MicroNetMFCC : public MFCC { + + public: + static constexpr uint32_t ms_defaultSamplingFreq = 16000; + static constexpr uint32_t ms_defaultNumFbankBins = 40; + static constexpr uint32_t ms_defaultMelLoFreq = 20; + static constexpr uint32_t ms_defaultMelHiFreq = 4000; + static constexpr bool ms_defaultUseHtkMethod = true; + + + explicit MicroNetMFCC(const size_t numFeats, const size_t frameLen) + : MFCC(MfccParams( + ms_defaultSamplingFreq, ms_defaultNumFbankBins, + ms_defaultMelLoFreq, ms_defaultMelHiFreq, + numFeats, frameLen, ms_defaultUseHtkMethod)) + {} + MicroNetMFCC() = delete; + ~MicroNetMFCC() = default; + }; + +} /* namespace audio */ +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_MICRONET_MFCC_HPP */ diff --git a/source/use_case/kws_asr/include/MicroNetKwsModel.hpp b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp new file mode 100644 index 0000000..22cf916 --- /dev/null +++ b/source/use_case/kws_asr/include/MicroNetKwsModel.hpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef KWS_ASR_MICRONETMODEL_HPP +#define KWS_ASR_MICRONETMODEL_HPP + +#include "Model.hpp" + +namespace arm { +namespace app { +namespace kws { + extern const int g_FrameLength; + extern const int g_FrameStride; + extern const float g_ScoreThreshold; + extern const uint32_t g_NumMfcc; + extern const uint32_t g_NumAudioWins; +} /* namespace kws */ +} /* namespace app */ +} /* namespace arm */ + +namespace arm { +namespace app { + class MicroNetKwsModel : public Model { + public: + /* Indices for the expected model - based on input and output tensor shapes */ + static constexpr uint32_t ms_inputRowsIdx = 1; + static constexpr uint32_t ms_inputColsIdx = 2; + static constexpr uint32_t ms_outputRowsIdx = 2; + static constexpr uint32_t ms_outputColsIdx = 3; + + protected: + /** @brief Gets the reference to op resolver interface class. */ + const tflite::MicroOpResolver& GetOpResolver() override; + + /** @brief Adds operations to the op resolver instance. */ + bool EnlistOperations() override; + + const uint8_t* ModelPointer() override; + + size_t ModelSize() override; + + private: + /* Maximum number of individual operations that can be enlisted. */ + static constexpr int ms_maxOpCnt = 7; + + /* A mutable op resolver instance. 
*/ + tflite::MicroMutableOpResolver m_opResolver; + }; + +} /* namespace app */ +} /* namespace arm */ + +#endif /* KWS_ASR_MICRONETMODEL_HPP */ diff --git a/source/use_case/kws_asr/src/AsrClassifier.cc b/source/use_case/kws_asr/src/AsrClassifier.cc index 57d5058..3f9cd7b 100644 --- a/source/use_case/kws_asr/src/AsrClassifier.cc +++ b/source/use_case/kws_asr/src/AsrClassifier.cc @@ -73,8 +73,9 @@ template bool arm::app::AsrClassifier::GetTopResults(TfLiteTensor* tenso bool arm::app::AsrClassifier::GetClassificationResults( TfLiteTensor* outputTensor, std::vector& vecResults, - const std::vector & labels, uint32_t topNCount) + const std::vector & labels, uint32_t topNCount, bool use_softmax) { + UNUSED(use_softmax); vecResults.clear(); constexpr int minTensorDims = static_cast( diff --git a/source/use_case/kws_asr/src/DsCnnModel.cc b/source/use_case/kws_asr/src/DsCnnModel.cc deleted file mode 100644 index 71d4ceb..0000000 --- a/source/use_case/kws_asr/src/DsCnnModel.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "DsCnnModel.hpp" - -#include "hal.h" - -namespace arm { -namespace app { -namespace kws { - extern uint8_t* GetModelPointer(); - extern size_t GetModelLen(); -} /* namespace kws */ -} /* namespace app */ -} /* namespace arm */ - -const tflite::MicroOpResolver& arm::app::DsCnnModel::GetOpResolver() -{ - return this->m_opResolver; -} - -bool arm::app::DsCnnModel::EnlistOperations() -{ - this->m_opResolver.AddAveragePool2D(); - this->m_opResolver.AddConv2D(); - this->m_opResolver.AddDepthwiseConv2D(); - this->m_opResolver.AddFullyConnected(); - this->m_opResolver.AddRelu(); - this->m_opResolver.AddSoftmax(); - this->m_opResolver.AddQuantize(); - this->m_opResolver.AddDequantize(); - this->m_opResolver.AddReshape(); - -#if defined(ARM_NPU) - if (kTfLiteOk == this->m_opResolver.AddEthosU()) { - info("Added %s support to op resolver\n", - tflite::GetString_ETHOSU()); - } else { - printf_err("Failed to add Arm NPU support to op resolver."); - return false; - } -#endif /* ARM_NPU */ - return true; -} - -const uint8_t* arm::app::DsCnnModel::ModelPointer() -{ - return arm::app::kws::GetModelPointer(); -} - -size_t arm::app::DsCnnModel::ModelSize() -{ - return arm::app::kws::GetModelLen(); -} \ No newline at end of file diff --git a/source/use_case/kws_asr/src/MainLoop.cc b/source/use_case/kws_asr/src/MainLoop.cc index d5a2c2b..30cb084 100644 --- a/source/use_case/kws_asr/src/MainLoop.cc +++ b/source/use_case/kws_asr/src/MainLoop.cc @@ -16,11 +16,11 @@ */ #include "hal.h" /* Brings in platform definitions. */ #include "InputFiles.hpp" /* For input images. */ -#include "Labels_dscnn.hpp" /* For DS-CNN label strings. */ +#include "Labels_micronetkws.hpp" /* For MicroNetKws label strings. */ #include "Labels_wav2letter.hpp" /* For Wav2Letter label strings. */ #include "Classifier.hpp" /* KWS classifier. */ #include "AsrClassifier.hpp" /* ASR classifier. */ -#include "DsCnnModel.hpp" /* KWS model class for running inference. 
*/ +#include "MicroNetKwsModel.hpp" /* KWS model class for running inference. */ #include "Wav2LetterModel.hpp" /* ASR model class for running inference. */ #include "UseCaseCommonUtils.hpp" /* Utils functions. */ #include "UseCaseHandler.hpp" /* Handlers for different user options. */ @@ -69,7 +69,7 @@ static uint32_t GetOutputInnerLen(const arm::app::Model& model, void main_loop(hal_platform& platform) { /* Model wrapper objects. */ - arm::app::DsCnnModel kwsModel; + arm::app::MicroNetKwsModel kwsModel; arm::app::Wav2LetterModel asrModel; /* Load the models. */ @@ -81,7 +81,7 @@ void main_loop(hal_platform& platform) /* Initialise the asr model using the same allocator from KWS * to re-use the tensor arena. */ if (!asrModel.Init(kwsModel.GetAllocator())) { - printf_err("Failed to initalise ASR model\n"); + printf_err("Failed to initialise ASR model\n"); return; } @@ -137,7 +137,7 @@ void main_loop(hal_platform& platform) caseContext.Set&>("kwslabels", kwsLabels); /* Index of the kws outputs we trigger ASR on. */ - caseContext.Set("keywordindex", 2); + caseContext.Set("keywordindex", 9 ); /* Loop. */ bool executionSuccessful = true; diff --git a/source/use_case/kws_asr/src/MicroNetKwsModel.cc b/source/use_case/kws_asr/src/MicroNetKwsModel.cc new file mode 100644 index 0000000..4b44580 --- /dev/null +++ b/source/use_case/kws_asr/src/MicroNetKwsModel.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MicroNetKwsModel.hpp" + +#include "hal.h" + +namespace arm { +namespace app { +namespace kws { + extern uint8_t* GetModelPointer(); + extern size_t GetModelLen(); +} /* namespace kws */ +} /* namespace app */ +} /* namespace arm */ + +const tflite::MicroOpResolver& arm::app::MicroNetKwsModel::GetOpResolver() +{ + return this->m_opResolver; +} + +bool arm::app::MicroNetKwsModel::EnlistOperations() +{ + this->m_opResolver.AddAveragePool2D(); + this->m_opResolver.AddConv2D(); + this->m_opResolver.AddDepthwiseConv2D(); + this->m_opResolver.AddFullyConnected(); + this->m_opResolver.AddRelu(); + this->m_opResolver.AddReshape(); + +#if defined(ARM_NPU) + if (kTfLiteOk == this->m_opResolver.AddEthosU()) { + info("Added %s support to op resolver\n", + tflite::GetString_ETHOSU()); + } else { + printf_err("Failed to add Arm NPU support to op resolver."); + return false; + } +#endif /* ARM_NPU */ + return true; +} + +const uint8_t* arm::app::MicroNetKwsModel::ModelPointer() +{ + return arm::app::kws::GetModelPointer(); +} + +size_t arm::app::MicroNetKwsModel::ModelSize() +{ + return arm::app::kws::GetModelLen(); +} \ No newline at end of file diff --git a/source/use_case/kws_asr/src/UseCaseHandler.cc b/source/use_case/kws_asr/src/UseCaseHandler.cc index 1d88ba1..c67be22 100644 --- a/source/use_case/kws_asr/src/UseCaseHandler.cc +++ b/source/use_case/kws_asr/src/UseCaseHandler.cc @@ -20,8 +20,8 @@ #include "InputFiles.hpp" #include "AudioUtils.hpp" #include "UseCaseCommonUtils.hpp" -#include "DsCnnModel.hpp" -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsModel.hpp" +#include "MicroNetKwsMfcc.hpp" #include "Classifier.hpp" #include "KwsResult.hpp" #include "Wav2LetterMfcc.hpp" @@ -77,12 +77,12 @@ namespace app { * * @param[in] mfcc MFCC feature calculator. * @param[in,out] inputTensor Input tensor pointer to store calculated features. 
- * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors). + * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors). * * @return function function to be called providing audio sample and sliding window index. **/ static std::function&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, + GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize); @@ -98,8 +98,8 @@ namespace app { constexpr uint32_t dataPsnTxtInfStartY = 40; constexpr int minTensorDims = static_cast( - (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)? - arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx); + (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)? + arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx); KWSOutput output; @@ -128,7 +128,7 @@ namespace app { const uint32_t kwsNumMfccFeats = ctx.Get("kwsNumMfcc"); const uint32_t kwsNumAudioWindows = ctx.Get("kwsNumAudioWins"); - audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength); + audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength); kwsMfcc.Init(); /* Deduce the data length required for 1 KWS inference from the network parameters. */ @@ -152,7 +152,7 @@ namespace app { /* We expect to be sampling 1 second worth of data at a time * NOTE: This is only used for time stamp calculation. 
*/ - const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq; + const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq; auto currentIndex = ctx.Get("clipIndex"); @@ -230,7 +230,7 @@ namespace app { kwsClassifier.GetClassificationResults( kwsOutputTensor, kwsClassificationResult, - ctx.Get&>("kwslabels"), 1); + ctx.Get&>("kwslabels"), 1, true); kwsResults.emplace_back( kws::KwsResult( @@ -604,7 +604,7 @@ namespace app { static std::function&, int, bool, size_t)> - GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) + GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize) { std::function&, size_t, bool, size_t)> mfccFeatureCalc; diff --git a/source/use_case/kws_asr/usecase.cmake b/source/use_case/kws_asr/usecase.cmake index d8629b6..b3fe020 100644 --- a/source/use_case/kws_asr/usecase.cmake +++ b/source/use_case/kws_asr/usecase.cmake @@ -45,7 +45,7 @@ USER_OPTION(${use_case}_AUDIO_MIN_SAMPLES "Specify the minimum number of samples # Generate kws labels file: USER_OPTION(${use_case}_LABELS_TXT_FILE_KWS "Labels' txt file for the chosen model." 
- ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/ds_cnn_labels.txt + ${CMAKE_CURRENT_SOURCE_DIR}/resources/${use_case}/labels/micronet_kws_labels.txt FILEPATH) # Generate asr labels file: @@ -67,10 +67,10 @@ USER_OPTION(${use_case}_MODEL_SCORE_THRESHOLD_ASR "Specify the score threshold [ STRING) if (ETHOS_U_NPU_ENABLED) - set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) + set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/kws_micronet_m_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) set(DEFAULT_MODEL_PATH_ASR ${DEFAULT_MODEL_DIR}/wav2letter_pruned_int8_vela_${ETHOS_U_NPU_CONFIG_ID}.tflite) else() - set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/ds_cnn_clustered_int8.tflite) + set(DEFAULT_MODEL_PATH_KWS ${DEFAULT_MODEL_DIR}/kws_micronet_m.tflite) set(DEFAULT_MODEL_PATH_ASR ${DEFAULT_MODEL_DIR}/wav2letter_pruned_int8.tflite) endif() @@ -134,7 +134,7 @@ generate_labels_code( INPUT "${${use_case}_LABELS_TXT_FILE_KWS}" DESTINATION_SRC ${SRC_GEN_DIR} DESTINATION_HDR ${INC_GEN_DIR} - OUTPUT_FILENAME "Labels_dscnn" + OUTPUT_FILENAME "Labels_micronetkws" NAMESPACE "arm" "app" "kws" ) diff --git a/tests/common/ClassifierTests.cc b/tests/common/ClassifierTests.cc index d950304..693f744 100644 --- a/tests/common/ClassifierTests.cc +++ b/tests/common/ClassifierTests.cc @@ -35,7 +35,7 @@ void test_classifier_result(std::vector>& selectedResults } arm::app::Classifier classifier; - REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5)); + REQUIRE(classifier.GetClassificationResults(outputTensor, resultVec, labels, 5, true)); REQUIRE(5 == resultVec.size()); for (size_t i = 0; i < resultVec.size(); ++i) { @@ -50,7 +50,7 @@ TEST_CASE("Common classifier") TfLiteTensor* outputTens = nullptr; std::vector resultVec; arm::app::Classifier classifier; - REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5)); + REQUIRE(!classifier.GetClassificationResults(outputTens, resultVec, {}, 5, 
true)); } SECTION("Test classification results") diff --git a/tests/use_case/kws/InferenceTestDSCNN.cc b/tests/use_case/kws/InferenceTestDSCNN.cc deleted file mode 100644 index 8918073..0000000 --- a/tests/use_case/kws/InferenceTestDSCNN.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "DsCnnModel.hpp" -#include "hal.h" -#include "TestData_kws.hpp" -#include "TensorFlowLiteMicro.hpp" - -#include -#include - -using namespace test; - -bool RunInference(arm::app::Model& model, const int8_t vec[]) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - REQUIRE(inputTensor); - - const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ? 
- inputTensor->bytes : - IFM_0_DATA_SIZE; - memcpy(inputTensor->data.data, vec, copySz); - - return model.RunInference(); -} - -bool RunInferenceRandom(arm::app::Model& model) -{ - TfLiteTensor* inputTensor = model.GetInputTensor(0); - REQUIRE(inputTensor); - - std::random_device rndDevice; - std::mt19937 mersenneGen{rndDevice()}; - std::uniform_int_distribution dist {-128, 127}; - - auto gen = [&dist, &mersenneGen](){ - return dist(mersenneGen); - }; - - std::vector randomAudio(inputTensor->bytes); - std::generate(std::begin(randomAudio), std::end(randomAudio), gen); - - REQUIRE(RunInference(model, randomAudio.data())); - return true; -} - -template -void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model) -{ - REQUIRE(RunInference(model, input_goldenFV)); - - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - - REQUIRE(outputTensor); - REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE); - auto tensorData = tflite::GetTensorData(outputTensor); - REQUIRE(tensorData); - - for (size_t i = 0; i < outputTensor->bytes; i++) { - REQUIRE(static_cast(tensorData[i]) == static_cast(((T)output_goldenFV[i]))); - } -} - -TEST_CASE("Running random inference with TensorFlow Lite Micro and DsCnnModel Int8", "[DS_CNN]") -{ - arm::app::DsCnnModel model{}; - - REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); - REQUIRE(model.IsInited()); - - REQUIRE(RunInferenceRandom(model)); -} - -TEST_CASE("Running inference with TensorFlow Lite Micro and DsCnnModel Uint8", "[DS_CNN]") -{ - REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES); - for (uint32_t i = 0 ; i < NUMBER_OF_IFM_FILES; ++i) { - const int8_t* input_goldenFV = get_ifm_data_array(i);; - const int8_t* output_goldenFV = get_ofm_data_array(i); - - DYNAMIC_SECTION("Executing inference with re-init") - { - arm::app::DsCnnModel model{}; - - REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); - REQUIRE(model.IsInited()); - - TestInference(input_goldenFV, output_goldenFV, model); - - 
} - } -} diff --git a/tests/use_case/kws/InferenceTestMicroNetKws.cc b/tests/use_case/kws/InferenceTestMicroNetKws.cc new file mode 100644 index 0000000..e6e7753 --- /dev/null +++ b/tests/use_case/kws/InferenceTestMicroNetKws.cc @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MicroNetKwsModel.hpp" +#include "hal.h" +#include "TestData_kws.hpp" +#include "TensorFlowLiteMicro.hpp" + +#include +#include + +using namespace test; + +bool RunInference(arm::app::Model& model, const int8_t vec[]) +{ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + REQUIRE(inputTensor); + + const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ? 
+ inputTensor->bytes : + IFM_0_DATA_SIZE; + memcpy(inputTensor->data.data, vec, copySz); + + return model.RunInference(); +} + +bool RunInferenceRandom(arm::app::Model& model) +{ + TfLiteTensor* inputTensor = model.GetInputTensor(0); + REQUIRE(inputTensor); + + std::random_device rndDevice; + std::mt19937 mersenneGen{rndDevice()}; + std::uniform_int_distribution dist {-128, 127}; + + auto gen = [&dist, &mersenneGen](){ + return dist(mersenneGen); + }; + + std::vector randomAudio(inputTensor->bytes); + std::generate(std::begin(randomAudio), std::end(randomAudio), gen); + + REQUIRE(RunInference(model, randomAudio.data())); + return true; +} + +template +void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model) +{ + REQUIRE(RunInference(model, input_goldenFV)); + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + + REQUIRE(outputTensor); + REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE); + auto tensorData = tflite::GetTensorData(outputTensor); + REQUIRE(tensorData); + + for (size_t i = 0; i < outputTensor->bytes; i++) { + REQUIRE(static_cast(tensorData[i]) == static_cast(((T)output_goldenFV[i]))); + } +} + +TEST_CASE("Running random inference with TensorFlow Lite Micro and MicroNetKwsModel Int8", "[MicroNetKws]") +{ + arm::app::MicroNetKwsModel model{}; + + REQUIRE_FALSE(model.IsInited()); + REQUIRE(model.Init()); + REQUIRE(model.IsInited()); + + REQUIRE(RunInferenceRandom(model)); +} + +TEST_CASE("Running inference with TensorFlow Lite Micro and MicroNetKwsModel int8", "[MicroNetKws]") +{ + REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES); + for (uint32_t i = 0 ; i < NUMBER_OF_IFM_FILES; ++i) { + const int8_t* input_goldenFV = get_ifm_data_array(i);; + const int8_t* output_goldenFV = get_ofm_data_array(i); + + DYNAMIC_SECTION("Executing inference with re-init") + { + arm::app::MicroNetKwsModel model{}; + + REQUIRE_FALSE(model.IsInited()); + REQUIRE(model.Init()); + REQUIRE(model.IsInited()); + + 
TestInference(input_goldenFV, output_goldenFV, model); + + } + } +} diff --git a/tests/use_case/kws/KWSHandlerTest.cc b/tests/use_case/kws/KWSHandlerTest.cc index 50e5a83..a7e75fb 100644 --- a/tests/use_case/kws/KWSHandlerTest.cc +++ b/tests/use_case/kws/KWSHandlerTest.cc @@ -15,7 +15,7 @@ * limitations under the License. */ #include -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "hal.h" #include "KwsResult.hpp" @@ -27,7 +27,7 @@ TEST_CASE("Model info") { /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -53,7 +53,7 @@ TEST_CASE("Inference by index") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -65,8 +65,8 @@ TEST_CASE("Inference by index") caseContext.Set("profiler", profiler); caseContext.Set("platform", platform); caseContext.Set("model", model); - caseContext.Set("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */ - caseContext.Set("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */ + caseContext.Set("frameLength", g_FrameLength); /* 640 sample length for MicroNetKws. */ + caseContext.Set("frameStride", g_FrameStride); /* 320 sample stride for MicroNetKws. */ caseContext.Set("scoreThreshold", 0.5); /* Normalised score threshold. */ arm::app::Classifier classifier; /* classifier wrapper object. */ @@ -97,25 +97,25 @@ TEST_CASE("Inference by index") SECTION("Index = 0, short clip down") { /* Result: down. */ - checker(0, {5}); + checker(0, {0}); } SECTION("Index = 1, long clip right->left->up") { /* Result: right->right->left->up->up. */ - checker(1, {7, 1, 6, 4, 4}); + checker(1, {6, 6, 2, 8, 8}); } SECTION("Index = 2, short clip yes") { /* Result: yes. */ - checker(2, {2}); + checker(2, {9}); } SECTION("Index = 3, long clip yes->no->go->stop") { - /* Result: yes->go->no->go->go->go->stop. 
*/ - checker(3, {2, 11, 3, 11, 11, 11, 10}); + /* Result: yes->no->no->go->go->stop->stop. */ + checker(3, {9, 3, 3, 1, 1, 7, 7}); } } @@ -132,7 +132,7 @@ TEST_CASE("Inference run all clips") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); @@ -145,9 +145,9 @@ TEST_CASE("Inference run all clips") caseContext.Set("platform", platform); caseContext.Set("model", model); caseContext.Set("clipIndex", 0); - caseContext.Set("frameLength", g_FrameLength); /* 640 sample length for DSCNN. */ - caseContext.Set("frameStride", g_FrameStride); /* 320 sample stride for DSCNN. */ - caseContext.Set("scoreThreshold", 0.9); /* Normalised score threshold. */ + caseContext.Set("frameLength", g_FrameLength); /* 640 sample length for MicroNet. */ + caseContext.Set("frameStride", g_FrameStride); /* 320 sample stride for MicroNet. */ + caseContext.Set("scoreThreshold", 0.7); /* Normalised score threshold. */ arm::app::Classifier classifier; /* classifier wrapper object. */ caseContext.Set("classifier", classifier); @@ -170,7 +170,7 @@ TEST_CASE("List all audio clips") hal_platform_init(&platform); /* Model wrapper object. */ - arm::app::DsCnnModel model; + arm::app::MicroNetKwsModel model; /* Load the model. */ REQUIRE(model.Init()); diff --git a/tests/use_case/kws/MfccTests.cc b/tests/use_case/kws/MfccTests.cc index 407861f..1d30ef4 100644 --- a/tests/use_case/kws/MfccTests.cc +++ b/tests/use_case/kws/MfccTests.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsMfcc.hpp" #include #include @@ -93,13 +93,13 @@ const std::vector testWavMfcc { -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072, }; -arm::app::audio::DsCnnMFCC GetMFCCInstance() { - const int sampFreq = arm::app::audio::DsCnnMFCC::ms_defaultSamplingFreq; +arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() { + const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; const int frameLenMs = 40; const int frameLenSamples = sampFreq * frameLenMs * 0.001; const int numMfccFeats = 10; - return arm::app::audio::DsCnnMFCC(numMfccFeats, frameLenSamples); + return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples); } template diff --git a/tests/use_case/kws_asr/InferenceTestDSCNN.cc b/tests/use_case/kws_asr/InferenceTestDSCNN.cc deleted file mode 100644 index ad1731b..0000000 --- a/tests/use_case/kws_asr/InferenceTestDSCNN.cc +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "DsCnnModel.hpp" -#include "hal.h" -#include "TestData_kws.hpp" -#include "TensorFlowLiteMicro.hpp" - -#include -#include - -namespace test { -namespace kws { - -bool RunInference(arm::app::Model& model, const int8_t vec[]) { - TfLiteTensor* inputTensor = model.GetInputTensor(0); - REQUIRE(inputTensor); - - const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ? - inputTensor->bytes : - IFM_0_DATA_SIZE; - memcpy(inputTensor->data.data, vec, copySz); - - return model.RunInference(); -} - -bool RunInferenceRandom(arm::app::Model& model) { - TfLiteTensor* inputTensor = model.GetInputTensor(0); - REQUIRE(inputTensor); - - std::random_device rndDevice; - std::mt19937 mersenneGen{rndDevice()}; - std::uniform_int_distribution dist{-128, 127}; - - auto gen = [&dist, &mersenneGen]() { - return dist(mersenneGen); - }; - - std::vector randomAudio(inputTensor->bytes); - std::generate(std::begin(randomAudio), std::end(randomAudio), gen); - - REQUIRE(RunInference(model, randomAudio.data())); - return true; -} - -template -void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model) { - REQUIRE(RunInference(model, input_goldenFV)); - - TfLiteTensor* outputTensor = model.GetOutputTensor(0); - - REQUIRE(outputTensor); - REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE); - auto tensorData = tflite::GetTensorData(outputTensor); - REQUIRE(tensorData); - - for (size_t i = 0; i < outputTensor->bytes; i++) { - REQUIRE(static_cast(tensorData[i]) == static_cast((T) output_goldenFV[i])); - } -} - -TEST_CASE("Running random inference with Tflu and DsCnnModel Int8", "[DS_CNN]") { - arm::app::DsCnnModel model{}; - - REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); - REQUIRE(model.IsInited()); - - REQUIRE(RunInferenceRandom(model)); -} - -TEST_CASE("Running inference with Tflu and DsCnnModel Uint8", "[DS_CNN]") { - REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES); - for (uint32_t i = 0; i < NUMBER_OF_IFM_FILES; ++i) { - const 
int8_t* input_goldenFV = get_ifm_data_array(i); - const int8_t* output_goldenFV = get_ofm_data_array(i); - - DYNAMIC_SECTION("Executing inference with re-init") { - arm::app::DsCnnModel model{}; - - REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); - REQUIRE(model.IsInited()); - - TestInference(input_goldenFV, output_goldenFV, model); - - } - } -} - -} //namespace -} //namespace \ No newline at end of file diff --git a/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc b/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc new file mode 100644 index 0000000..fd379b6 --- /dev/null +++ b/tests/use_case/kws_asr/InferenceTestMicroNetKws.cc @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2021 Arm Limited. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MicroNetKwsModel.hpp" +#include "hal.h" +#include "TestData_kws.hpp" +#include "TensorFlowLiteMicro.hpp" + +#include +#include + +namespace test { +namespace kws { + +bool RunInference(arm::app::Model& model, const int8_t vec[]) { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + REQUIRE(inputTensor); + + const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ? 
+ inputTensor->bytes : + IFM_0_DATA_SIZE; + memcpy(inputTensor->data.data, vec, copySz); + + return model.RunInference(); +} + +bool RunInferenceRandom(arm::app::Model& model) { + TfLiteTensor* inputTensor = model.GetInputTensor(0); + REQUIRE(inputTensor); + + std::random_device rndDevice; + std::mt19937 mersenneGen{rndDevice()}; + std::uniform_int_distribution dist{-128, 127}; + + auto gen = [&dist, &mersenneGen]() { + return dist(mersenneGen); + }; + + std::vector randomAudio(inputTensor->bytes); + std::generate(std::begin(randomAudio), std::end(randomAudio), gen); + + REQUIRE(RunInference(model, randomAudio.data())); + return true; +} + +template +void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model) { + REQUIRE(RunInference(model, input_goldenFV)); + + TfLiteTensor* outputTensor = model.GetOutputTensor(0); + + REQUIRE(outputTensor); + REQUIRE(outputTensor->bytes == OFM_0_DATA_SIZE); + auto tensorData = tflite::GetTensorData(outputTensor); + REQUIRE(tensorData); + + for (size_t i = 0; i < outputTensor->bytes; i++) { + REQUIRE(static_cast(tensorData[i]) == static_cast((T) output_goldenFV[i])); + } +} + +TEST_CASE("Running random inference with Tflu and MicroNetKwsModel Int8", "[MicroNetKws]") { + arm::app::MicroNetKwsModel model{}; + + REQUIRE_FALSE(model.IsInited()); + REQUIRE(model.Init()); + REQUIRE(model.IsInited()); + + REQUIRE(RunInferenceRandom(model)); +} + +TEST_CASE("Running inference with Tflu and MicroNetKwsModel Int8", "[MicroNetKws]") { + REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES); + for (uint32_t i = 0; i < NUMBER_OF_IFM_FILES; ++i) { + const int8_t* input_goldenFV = get_ifm_data_array(i); + const int8_t* output_goldenFV = get_ofm_data_array(i); + + DYNAMIC_SECTION("Executing inference with re-init") { + arm::app::MicroNetKwsModel model{}; + + REQUIRE_FALSE(model.IsInited()); + REQUIRE(model.Init()); + REQUIRE(model.IsInited()); + + TestInference(input_goldenFV, output_goldenFV, model); + + } + } +} 
+ +} //namespace +} //namespace \ No newline at end of file diff --git a/tests/use_case/kws_asr/InitModels.cc b/tests/use_case/kws_asr/InitModels.cc index 770944d..97aa092 100644 --- a/tests/use_case/kws_asr/InitModels.cc +++ b/tests/use_case/kws_asr/InitModels.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "DsCnnModel.hpp" +#include "MicroNetKwsModel.hpp" #include "Wav2LetterModel.hpp" #include @@ -22,8 +22,8 @@ /* Skip this test, Wav2LetterModel if not Vela optimized but only from ML-zoo will fail. */ TEST_CASE("Init two Models", "[.]") { - arm::app::DsCnnModel model1; - arm::app::DsCnnModel model2; + arm::app::MicroNetKwsModel model1; + arm::app::MicroNetKwsModel model2; /* Ideally we should load the wav2letter model here, but there is * none available to run on native (ops not supported on unoptimised diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc index 9509519..c0fb723 100644 --- a/tests/use_case/kws_asr/MfccTests.cc +++ b/tests/use_case/kws_asr/MfccTests.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "DsCnnMfcc.hpp" +#include "MicroNetKwsMfcc.hpp" #include #include @@ -93,17 +93,17 @@ const std::vector testWavMfcc { -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072, }; -arm::app::audio::DsCnnMFCC GetMFCCInstance() { - const int sampFreq = arm::app::audio::DsCnnMFCC::ms_defaultSamplingFreq; +arm::app::audio::MicroNetMFCC GetMFCCInstance() { + const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq; const int frameLenMs = 40; const int frameLenSamples = sampFreq * frameLenMs * 0.001; const int numMfccFeats = 10; - return arm::app::audio::DsCnnMFCC(numMfccFeats, frameLenSamples); + return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples); } template -void TestQuntisedMFCC() { +void TestQuantisedMFCC() { const float quantScale = 1.1088106632232666; const int quantOffset = 95; std::vector mfccOutput = GetMFCCInstance().MfccComputeQuant(testWav, quantScale, quantOffset); @@ -118,9 +118,9 @@ void TestQuntisedMFCC() { REQUIRE(quantizedTestWavMfcc == Approx(mfccOutput[i]).margin(0)); } } -template void TestQuntisedMFCC(); -template void TestQuntisedMFCC(); -template void TestQuntisedMFCC(); +template void TestQuantisedMFCC(); +template void TestQuantisedMFCC(); +template void TestQuantisedMFCC(); TEST_CASE("MFCC calculation test") { @@ -141,16 +141,16 @@ TEST_CASE("MFCC calculation test") SECTION("int8_t") { - TestQuntisedMFCC(); + TestQuantisedMFCC(); } SECTION("uint8_t") { - TestQuntisedMFCC(); + TestQuantisedMFCC(); } SECTION("MFCC quant calculation test - int16_t") { - TestQuntisedMFCC(); + TestQuantisedMFCC(); } } \ No newline at end of file -- cgit v1.2.1