aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/speech_recognition/audio_utils.py
blob: f03d2e1290e829d00cac7995c25cf29334a73dae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Utilities for speech recognition apps."""

import numpy as np
import pyarmnn as ann


def decode(model_output: np.ndarray, labels: dict) -> str:
    """Turn per-step classification scores into the final decoded string.

    For every step (row) of ``model_output`` the index of the highest score is
    looked up in ``labels``; the resulting label sequence is then cleaned up by
    ``filter_characters``.

    Args:
        model_output: Results from running inference, iterated row by row.
        labels: Dictionary of labels keyed on the classification index.

    Returns:
        Decoded string.
    """
    best_indices = map(np.argmax, model_output)
    best_labels = [labels[index] for index in best_indices]
    return filter_characters(best_labels)


def filter_characters(results: list) -> str:
    """Filters unwanted and duplicate characters.

    Drops every blank token ("$") and, within each run of identical consecutive
    symbols, keeps only the last one. A repeated symbol separated by a blank
    therefore survives twice (e.g. ["a", "$", "a"] -> "aa"), while
    ["a", "a"] collapses to "a".

    Args:
        results: List of top 1 results from inference (one symbol per step).

    Returns:
        Final output string to present to user.
    """
    count = len(results)
    kept = [
        symbol
        for i, symbol in enumerate(results)
        # Skip blanks, and skip a symbol when the very next step repeats it.
        if symbol != "$" and not (i + 1 < count and symbol == results[i + 1])
    ]
    # join is linear; repeated `+=` on a str can be quadratic.
    return "".join(kept)


def display_text(text: str):
    """Write ASR output to the console without a trailing newline.

    The output is flushed immediately so partial transcriptions appear as soon
    as they are produced.

    Args:
        text: Results of performing ASR on the input audio data.
    """
    # A single positional argument makes `sep` irrelevant.
    print(text, end="", flush=True)


def quantize_input(data, input_binding_info):
    """Quantize the float input to (u)int8 ready for inputting to model.

    Applies affine quantization, q = round(x / scale) + offset, saturated to
    the range of the target integer type. A new array is returned; ``data`` is
    not modified.

    Args:
        data: 2-D array (or array-like) of input features.
        input_binding_info: The model input binding info. Element ``[1]`` must
            provide GetQuantizationScale(), GetQuantizationOffset() and
            GetDataType().

    Returns:
        Array with the same shape as ``data`` and dtype np.int8 or np.uint8.

    Raises:
        RuntimeError: If ``data`` is not 2-dimensional.
        ValueError: If the input data type is neither QAsymmS8 nor QAsymmU8.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise RuntimeError("Audio data must have 2 dimensions for quantization")

    tensor_info = input_binding_info[1]
    quant_scale = tensor_info.GetQuantizationScale()
    quant_offset = tensor_info.GetQuantizationOffset()
    data_type = tensor_info.GetDataType()

    if data_type == ann.DataType_QAsymmS8:
        data_type = np.int8
    elif data_type == ann.DataType_QAsymmU8:
        data_type = np.uint8
    else:
        raise ValueError("Could not quantize data to required data type")

    d_min = np.iinfo(data_type).min
    d_max = np.iinfo(data_type).max

    # Round to nearest explicitly: astype() alone truncates toward zero, which
    # biases every value. Computing on a fresh array (rather than writing back
    # element by element) leaves the caller's features untouched.
    quantized = np.rint(data / quant_scale) + quant_offset
    quantized = np.clip(quantized, d_min, d_max)
    return quantized.astype(data_type)


def decode_text(is_first_window, labels, output_result):
    """
    Slices the wav2letter output appropriately depending on the window, and decodes it.
        * First window: decode the left context and the inner context.
        * Every other window: decode only the inner context.
    The right context is decoded on every call and returned so the caller can keep the
    latest one; it is needed after the last inference.

    Args:
        is_first_window: Boolean to show if it is the first window we are running inference on
        labels: the label set, keyed on the classification index
        output_result: the output from the inference; the per-step scores are read from
            output_result[0][0][0]
    Returns:
        current_r_context: the decoded right context of this window
        text: the decoded text of this window (left + inner context on the first window,
            inner context only on later windows)
    """
    # For wav2letter with 148 output steps:
    # Left context is index 0-48, inner context 49-98, right context 99-147.
    # The right context starts exactly where the (exclusive) inner context end is,
    # so no output step falls between the two slices.
    inner_context_start = 49
    inner_context_end = 99
    right_context_start = inner_context_end

    steps = output_result[0][0][0]

    if is_first_window:
        # Since it's the first inference, keep the left context, and inner context, and decode
        text = decode(steps[0:inner_context_end], labels)
    else:
        # Only decode the inner context
        text = decode(steps[inner_context_start:inner_context_end], labels)

    # Store the right context, we will need it after the last inference
    current_r_context = decode(steps[right_context_start:], labels)
    return current_r_context, text


def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
    """
    Takes a block of audio data, extracts the MFCC features, quantizes them if the model
    input is not Float32, and uses ArmNN to create the input tensors.

    Args:
        audio_data: The audio data to process
        input_binding_info: the model input binding info; element [1] provides GetDataType()
            (and the quantization parameters when quantization is needed)
        mfcc_preprocessor: the mfcc preprocessor instance; its extract_features() method is
            called on audio_data
    Returns:
        input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
    """
    data_type = input_binding_info[1].GetDataType()
    features = mfcc_preprocessor.extract_features(audio_data)
    # Float32 models take the features as-is; any other input type is quantized.
    if data_type != ann.DataType_Float32:
        features = quantize_input(features, input_binding_info)
    input_tensors = ann.make_input_tensors([input_binding_info], [features])
    return input_tensors