author     Raviv Shalev <raviv.shalev@arm.com>    2021-12-07 15:18:09 +0200
committer  TeresaARM <teresa.charlinreyes@arm.com>    2022-04-13 15:33:31 +0000
commit     97ddc06e52fbcabfd8ede7a00e9494c663186b92 (patch)
tree       43c84d352c3a67aa45d89760fba6c79b81c8f8dc /python/pyarmnn/examples/keyword_spotting
parent     2f0ddb67d8f9267ab600a8a26308cab32f9e16ac (diff)
MLECO-2493 Add python OD example with TFLite delegate
Signed-off-by: Raviv Shalev <raviv.shalev@arm.com>
Change-Id: I25fcccbf912be0c5bd4fbfd2e97552341958af35
Diffstat (limited to 'python/pyarmnn/examples/keyword_spotting')
-rw-r--r--  python/pyarmnn/examples/keyword_spotting/README.MD                    |  2
-rw-r--r--  python/pyarmnn/examples/keyword_spotting/run_audio_classification.py  | 11
2 files changed, 8 insertions, 5 deletions
diff --git a/python/pyarmnn/examples/keyword_spotting/README.MD b/python/pyarmnn/examples/keyword_spotting/README.MD
index d276c08f8e..dde8342e7f 100644
--- a/python/pyarmnn/examples/keyword_spotting/README.MD
+++ b/python/pyarmnn/examples/keyword_spotting/README.MD
@@ -166,7 +166,7 @@ mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
# audio_utils.py
# Quantize the input data and create input tensors with PyArmNN
input_tensor = quantize_input(input_tensor, input_binding_info)
-input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+input_tensors = ann.make_input_tensors([input_binding_info], [input_data])
```
Note: `ArmnnNetworkExecutor` has already created the output tensors for you.
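The README fix above makes the documented snippet consistent: the features are quantized first, and the quantized buffer is what gets wrapped into the PyArmNN input tensors. A minimal sketch of that flow, assuming the `quantize_input` helper from the example's `audio_utils.py`; `mfcc_feats` and `input_binding_info` are stand-ins for the objects the README builds earlier, and the result is named `input_data` to match the corrected line:

```python
import pyarmnn as ann

# Quantize the float MFCC features to the model's input data type
# (quantize_input is the helper defined in the example's audio_utils.py).
input_data = quantize_input(mfcc_feats, input_binding_info)

# Build the PyArmNN input tensors from the quantized buffer, one binding
# info per input; note it is input_data, not the raw features, that is wrapped.
input_tensors = ann.make_input_tensors([input_binding_info], [input_data])
```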
diff --git a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
index 6dfa4cc806..50ad1a8a2e 100644
--- a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
+++ b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
@@ -14,7 +14,7 @@ script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
from network_executor import ArmnnNetworkExecutor
-from utils import prepare_input_tensors, dequantize_output
+from utils import prepare_input_data, dequantize_output
from mfcc import AudioPreprocessor, MFCC, MFCCParams
from audio_utils import decode, display_text
from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
@@ -69,13 +69,16 @@ def parse_args():
def recognise_speech(audio_data, network, preprocessor, threshold):
# Prepare the input Tensors
- input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+ input_data = prepare_input_data(audio_data, network.get_data_type(), network.get_input_quantization_scale(0),
+ network.get_input_quantization_offset(0), preprocessor)
# Run inference
- output_result = network.run(input_tensors)
+ output_result = network.run([input_data])
dequantized_result = []
for index, ofm in enumerate(output_result):
- dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))
+ dequantized_result.append(dequantize_output(ofm, network.is_output_quantized(index),
+ network.get_output_quantization_scale(index),
+ network.get_output_quantization_offset(index)))
# Decode the text and display result if above threshold
decoded_result = decode(dequantized_result, labels)
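The hunk above captures the API migration this change rolls out across the examples: `ArmnnNetworkExecutor` now consumes raw feature arrays and exposes quantization parameters through getters, so callers no longer build tensors from binding info themselves. A sketch of the updated end-to-end flow under those assumptions; `model_path`, `backends`, `audio_data`, `preprocessor`, and `labels` are placeholders for objects the script constructs elsewhere:

```python
from network_executor import ArmnnNetworkExecutor
from utils import prepare_input_data, dequantize_output
from audio_utils import decode

network = ArmnnNetworkExecutor(model_path, backends)  # placeholder arguments

# Quantize the extracted features using the model's own input parameters.
input_data = prepare_input_data(audio_data, network.get_data_type(),
                                network.get_input_quantization_scale(0),
                                network.get_input_quantization_offset(0),
                                preprocessor)

# run() now takes a list of raw input arrays instead of pre-built tensors.
output_result = network.run([input_data])

# Dequantize each output feature map back to float before decoding.
dequantized_result = [dequantize_output(ofm, network.is_output_quantized(i),
                                        network.get_output_quantization_scale(i),
                                        network.get_output_quantization_offset(i))
                      for i, ofm in enumerate(output_result)]

decoded_result = decode(dequantized_result, labels)
```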