aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/keyword_spotting/run_audio_classification.py')
-rw-r--r--python/pyarmnn/examples/keyword_spotting/run_audio_classification.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
index 6dfa4cc806..50ad1a8a2e 100644
--- a/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
+++ b/python/pyarmnn/examples/keyword_spotting/run_audio_classification.py
@@ -14,7 +14,7 @@ script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
from network_executor import ArmnnNetworkExecutor
-from utils import prepare_input_tensors, dequantize_output
+from utils import prepare_input_data, dequantize_output
from mfcc import AudioPreprocessor, MFCC, MFCCParams
from audio_utils import decode, display_text
from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
@@ -69,13 +69,16 @@ def parse_args():
def recognise_speech(audio_data, network, preprocessor, threshold):
# Prepare the input Tensors
- input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)
+ input_data = prepare_input_data(audio_data, network.get_data_type(), network.get_input_quantization_scale(0),
+ network.get_input_quantization_offset(0), preprocessor)
# Run inference
- output_result = network.run(input_tensors)
+ output_result = network.run([input_data])
dequantized_result = []
for index, ofm in enumerate(output_result):
- dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))
+ dequantized_result.append(dequantize_output(ofm, network.is_output_quantized(index),
+ network.get_output_quantization_scale(index),
+ network.get_output_quantization_offset(index)))
# Decode the text and display result if above threshold
decoded_result = decode(dequantized_result, labels)