aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/examples/common
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/examples/common')
-rw-r--r--python/pyarmnn/examples/common/audio_capture.py149
-rw-r--r--python/pyarmnn/examples/common/cv_utils.py8
-rw-r--r--python/pyarmnn/examples/common/mfcc.py238
-rw-r--r--python/pyarmnn/examples/common/tests/conftest.py40
-rw-r--r--python/pyarmnn/examples/common/tests/context.py7
-rw-r--r--python/pyarmnn/examples/common/tests/test_network_executor.py24
-rw-r--r--python/pyarmnn/examples/common/tests/test_utils.py19
-rw-r--r--python/pyarmnn/examples/common/utils.py69
8 files changed, 460 insertions, 94 deletions
diff --git a/python/pyarmnn/examples/common/audio_capture.py b/python/pyarmnn/examples/common/audio_capture.py
new file mode 100644
index 0000000000..1bd53b4473
--- /dev/null
+++ b/python/pyarmnn/examples/common/audio_capture.py
@@ -0,0 +1,149 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+"""Contains CaptureAudioStream class for capturing chunks of audio data from incoming
+ stream and generic capture_audio function for capturing from files."""
+import collections
+import time
+from queue import Queue
+from typing import Generator
+
+import numpy as np
+import sounddevice as sd
+import soundfile as sf
+
+AudioCaptureParams = collections.namedtuple('AudioCaptureParams',
+ ['dtype', 'overlap', 'min_samples', 'sampling_freq', 'mono'])
+
+
+def capture_audio(audio_file_path, params_tuple) -> Generator[np.ndarray, None, None]:
+ """Creates a generator that yields audio data from a file. Data is padded with
+ zeros if necessary to make up minimum number of samples.
+ Args:
+ audio_file_path: Path to audio file provided by user.
+ params_tuple: Sampling parameters for model used
+ Yields:
+ Blocks of audio data of minimum sample size.
+ """
+ with sf.SoundFile(audio_file_path) as audio_file:
+ for block in audio_file.blocks(
+ blocksize=params_tuple.min_samples,
+ dtype=params_tuple.dtype,
+ always_2d=True,
+ fill_value=0,
+ overlap=params_tuple.overlap
+ ):
+ if params_tuple.mono and block.shape[0] > 1:
+ block = np.mean(block, dtype=block.dtype, axis=1)
+ yield block
+
+
+class CaptureAudioStream:
+
+ def __init__(self, audio_capture_params):
+ self.audio_capture_params = audio_capture_params
+ self.collection = np.zeros(self.audio_capture_params.min_samples + self.audio_capture_params.overlap).astype(
+ dtype=self.audio_capture_params.dtype)
+ self.is_active = True
+ self.is_first_window = True
+ self.duration = False
+ self.block_count = 0
+ self.current_block = 0
+ self.queue = Queue(2)
+
+ def set_stream_defaults(self):
+ """Discovers input devices on the system and sets default stream parameters."""
+ print(sd.query_devices())
+ device = input("Select input device by index or name: ")
+
+ try:
+ sd.default.device = int(device)
+ except ValueError:
+ sd.default.device = str(device)
+
+ sd.default.samplerate = self.audio_capture_params.sampling_freq
+ sd.default.blocksize = self.audio_capture_params.min_samples
+ sd.default.dtype = self.audio_capture_params.dtype
+ sd.default.channels = 1 if self.audio_capture_params.mono else 2
+
+ def set_recording_duration(self, duration):
+ """Sets a time duration (in integer seconds) for recording audio. Total time duration is
+ adjusted to a minimum based on the parameters of the model used. Durations less than 1
+ result in endless recording.
+
+ Args:
+ duration (int): User-provided command line argument for time duration of recording.
+ """
+ if duration > 0:
+ min_duration = int(
+ np.ceil(self.audio_capture_params.min_samples / self.audio_capture_params.sampling_freq)
+ )
+ if duration < min_duration:
+ print(f"Minimum duration must be {min_duration} seconds of audio")
+ print(f"Setting minimum recording duration...")
+ duration = min_duration
+
+ print(f"Recording duration is {duration} seconds")
+ self.duration = self.audio_capture_params.sampling_freq * duration
+ self.block_count, remainder_samples = divmod(
+ self.duration, self.audio_capture_params.min_samples
+ )
+
+ if remainder_samples > 0.5 * self.audio_capture_params.sampling_freq:
+ self.block_count += 1
+ else:
+ self.duration = False # Record forever
+
+ def countdown(self, delay=3):
+ """3 second countdown prior to recording audio."""
+ print("Beginning recording in...")
+ for i in range(delay, 0, -1):
+ print(f"{i}...")
+ time.sleep(1)
+
+ def update(self):
+ """If a duration has been set, increments a counter to update the number of blocks of audio
+ data left to be collected. The stream is deactivated upon reaching the maximum block count
+ determined by the duration.
+ """
+ if self.duration:
+ self.current_block += 1
+ if self.current_block == self.block_count:
+ self.is_active = False
+
+ def capture_data(self):
+ """Gets the next window of audio data by retrieving the newest data from a queue and
+ shifting the position of the data in the collection. Overlap values of less than `min_samples` are supported.
+ """
+ new_data = self.queue.get()
+
+ if self.is_first_window or self.audio_capture_params.overlap == 0:
+ self.collection[:self.audio_capture_params.min_samples] = new_data[:]
+
+ elif self.audio_capture_params.overlap < self.audio_capture_params.min_samples:
+ #
+ self.collection[0:self.audio_capture_params.overlap] = \
+ self.collection[(self.audio_capture_params.min_samples - self.audio_capture_params.overlap):
+ self.audio_capture_params.min_samples]
+
+ self.collection[self.audio_capture_params.overlap:(
+ self.audio_capture_params.overlap + self.audio_capture_params.min_samples)] = new_data[:]
+ else:
+ raise ValueError(
+ "Capture Error: Overlap must be less than {}".format(self.audio_capture_params.min_samples))
+ audio_data = self.collection[0:self.audio_capture_params.min_samples]
+ return np.asarray(audio_data).astype(self.audio_capture_params.dtype)
+
+ def callback(self, data, frames, time, status):
+ """Places audio data from active stream into a queue for processing.
+ Update counter if recording duration is finite.
+ """
+
+ if self.duration:
+ self.update()
+
+ if self.audio_capture_params.mono:
+ audio_data = data.copy().flatten()
+ else:
+ audio_data = data.copy()
+
+ self.queue.put(audio_data)
diff --git a/python/pyarmnn/examples/common/cv_utils.py b/python/pyarmnn/examples/common/cv_utils.py
index fd848b8b0f..e12ff50548 100644
--- a/python/pyarmnn/examples/common/cv_utils.py
+++ b/python/pyarmnn/examples/common/cv_utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2020-2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""
@@ -14,7 +14,7 @@ import numpy as np
import pyarmnn as ann
-def preprocess(frame: np.ndarray, input_binding_info: tuple):
+def preprocess(frame: np.ndarray, input_binding_info: tuple, is_normalised: bool):
"""
Takes a frame, resizes, swaps channels and converts data type to match
model input layer. The converted frame is wrapped in a const tensor
@@ -23,6 +23,7 @@ def preprocess(frame: np.ndarray, input_binding_info: tuple):
Args:
frame: Captured frame from video.
input_binding_info: Contains shape and data type of model input layer.
+ is_normalised: if the input layer expects normalised data
Returns:
Input tensor.
@@ -34,7 +35,8 @@ def preprocess(frame: np.ndarray, input_binding_info: tuple):
# Expand dimensions and convert data type to match model input
if input_binding_info[1].GetDataType() == ann.DataType_Float32:
data_type = np.float32
- resized_frame = resized_frame.astype("float32")/255
+ if is_normalised:
+ resized_frame = resized_frame.astype("float32")/255
else:
data_type = np.uint8
diff --git a/python/pyarmnn/examples/common/mfcc.py b/python/pyarmnn/examples/common/mfcc.py
new file mode 100644
index 0000000000..2bab669fb7
--- /dev/null
+++ b/python/pyarmnn/examples/common/mfcc.py
@@ -0,0 +1,238 @@
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Class used to extract the Mel-frequency cepstral coefficients from a given audio frame."""
+
+import numpy as np
+import collections
+
+MFCCParams = collections.namedtuple('MFCCParams', ['sampling_freq', 'num_fbank_bins', 'mel_lo_freq', 'mel_hi_freq',
+ 'num_mfcc_feats', 'frame_len', 'use_htk_method', 'n_fft'])
+
+
+class MFCC:
+
+ def __init__(self, mfcc_params):
+ self.mfcc_params = mfcc_params
+ self.FREQ_STEP = 200.0 / 3
+ self.MIN_LOG_HZ = 1000.0
+ self.MIN_LOG_MEL = self.MIN_LOG_HZ / self.FREQ_STEP
+ self.LOG_STEP = 1.8562979903656 / 27.0
+ self._frame_len_padded = int(2 ** (np.ceil((np.log(self.mfcc_params.frame_len) / np.log(2.0)))))
+ self._filter_bank_initialised = False
+ self.__frame = np.zeros(self._frame_len_padded)
+ self.__buffer = np.zeros(self._frame_len_padded)
+ self._filter_bank_filter_first = np.zeros(self.mfcc_params.num_fbank_bins)
+ self._filter_bank_filter_last = np.zeros(self.mfcc_params.num_fbank_bins)
+ self.__mel_energies = np.zeros(self.mfcc_params.num_fbank_bins)
+ self._dct_matrix = self.create_dct_matrix(self.mfcc_params.num_fbank_bins, self.mfcc_params.num_mfcc_feats)
+ self.__mel_filter_bank = self.create_mel_filter_bank()
+ self._np_mel_bank = np.zeros([self.mfcc_params.num_fbank_bins, int(self.mfcc_params.n_fft / 2) + 1])
+
+ for i in range(self.mfcc_params.num_fbank_bins):
+ k = 0
+ for j in range(int(self._filter_bank_filter_first[i]), int(self._filter_bank_filter_last[i]) + 1):
+ self._np_mel_bank[i, j] = self.__mel_filter_bank[i][k]
+ k += 1
+
+ def mel_scale(self, freq, use_htk_method):
+ """
+ Gets the mel scale for a particular sample frequency.
+
+ Args:
+ freq: The sampling frequency.
+ use_htk_method: Boolean to set whether to use HTK method or not.
+
+ Returns:
+ the mel scale
+ """
+ if use_htk_method:
+ return 1127.0 * np.log(1.0 + freq / 700.0)
+ else:
+ mel = freq / self.FREQ_STEP
+
+ if freq >= self.MIN_LOG_HZ:
+ mel = self.MIN_LOG_MEL + np.log(freq / self.MIN_LOG_HZ) / self.LOG_STEP
+ return mel
+
+ def inv_mel_scale(self, mel_freq, use_htk_method):
+ """
+ Gets the sample frequency for a particular mel.
+
+ Args:
+ mel_freq: The mel frequency.
+ use_htk_method: Boolean to set whether to use HTK method or not.
+
+ Returns:
+ the sample frequency
+ """
+ if use_htk_method:
+ return 700.0 * (np.exp(mel_freq / 1127.0) - 1.0)
+ else:
+ freq = self.FREQ_STEP * mel_freq
+
+ if mel_freq >= self.MIN_LOG_MEL:
+ freq = self.MIN_LOG_HZ * np.exp(self.LOG_STEP * (mel_freq - self.MIN_LOG_MEL))
+ return freq
+
+ def spectrum_calc(self, audio_data):
+ return np.abs(np.fft.rfft(np.hanning(self.mfcc_params.frame_len + 1)[0:self.mfcc_params.frame_len] * audio_data,
+ self.mfcc_params.n_fft))
+
+ def log_mel(self, mel_energy):
+ mel_energy += 1e-10 # Avoid division by zero
+ return np.log(mel_energy)
+
+ def mfcc_compute(self, audio_data):
+ """
+ Extracts the MFCC for a single frame.
+
+ Args:
+ audio_data: The audio data to process.
+
+ Returns:
+ the MFCC features
+ """
+ if len(audio_data) != self.mfcc_params.frame_len:
+ raise ValueError(
+ f"audio_data buffer size {len(audio_data)} does not match frame length {self.mfcc_params.frame_len}")
+
+ audio_data = np.array(audio_data)
+ spec = self.spectrum_calc(audio_data)
+ mel_energy = np.dot(self._np_mel_bank.astype(np.float32),
+ np.transpose(spec).astype(np.float32))
+ log_mel_energy = self.log_mel(mel_energy)
+ mfcc_feats = np.dot(self._dct_matrix, log_mel_energy)
+ return mfcc_feats
+
+ def create_dct_matrix(self, num_fbank_bins, num_mfcc_feats):
+ """
+ Creates the Discrete Cosine Transform matrix to be used in the compute function.
+
+ Args:
+ num_fbank_bins: The number of filter bank bins
+ num_mfcc_feats: the number of MFCC features
+
+ Returns:
+ the DCT matrix
+ """
+
+ dct_m = np.zeros(num_fbank_bins * num_mfcc_feats)
+ for k in range(num_mfcc_feats):
+ for n in range(num_fbank_bins):
+ dct_m[(k * num_fbank_bins) + n] = (np.sqrt(2 / num_fbank_bins)) * np.cos(
+ (np.pi / num_fbank_bins) * (n + 0.5) * k)
+ dct_m = np.reshape(dct_m, [self.mfcc_params.num_mfcc_feats, self.mfcc_params.num_fbank_bins])
+ return dct_m
+
+ def mel_norm(self, weight, right_mel, left_mel):
+ """
+ Placeholder function over-ridden in child class
+ """
+ return weight
+
+ def create_mel_filter_bank(self):
+ """
+ Creates the Mel filter bank.
+
+ Returns:
+ the mel filter bank
+ """
+ # FFT calculations are greatly accelerated for frame lengths which are powers of 2
+ # Frames are padded and FFT bin width/length calculated accordingly
+ num_fft_bins = int(self._frame_len_padded / 2)
+ fft_bin_width = self.mfcc_params.sampling_freq / self._frame_len_padded
+
+ mel_low_freq = self.mel_scale(self.mfcc_params.mel_lo_freq, self.mfcc_params.use_htk_method)
+ mel_high_freq = self.mel_scale(self.mfcc_params.mel_hi_freq, self.mfcc_params.use_htk_method)
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (self.mfcc_params.num_fbank_bins + 1)
+
+ this_bin = np.zeros(num_fft_bins)
+ mel_fbank = [0] * self.mfcc_params.num_fbank_bins
+ for bin_num in range(self.mfcc_params.num_fbank_bins):
+ left_mel = mel_low_freq + bin_num * mel_freq_delta
+ center_mel = mel_low_freq + (bin_num + 1) * mel_freq_delta
+ right_mel = mel_low_freq + (bin_num + 2) * mel_freq_delta
+ first_index = last_index = -1
+
+ for i in range(num_fft_bins):
+ freq = (fft_bin_width * i)
+ mel = self.mel_scale(freq, self.mfcc_params.use_htk_method)
+ this_bin[i] = 0.0
+
+ if (mel > left_mel) and (mel < right_mel):
+ if mel <= center_mel:
+ weight = (mel - left_mel) / (center_mel - left_mel)
+ else:
+ weight = (right_mel - mel) / (right_mel - center_mel)
+
+ this_bin[i] = self.mel_norm(weight, right_mel, left_mel)
+
+ if first_index == -1:
+ first_index = i
+ last_index = i
+
+ self._filter_bank_filter_first[bin_num] = first_index
+ self._filter_bank_filter_last[bin_num] = last_index
+ mel_fbank[bin_num] = np.zeros(last_index - first_index + 1)
+ j = 0
+
+ for i in range(first_index, last_index + 1):
+ mel_fbank[bin_num][j] = this_bin[i]
+ j += 1
+
+ return mel_fbank
+
+
+class AudioPreprocessor:
+
+ def __init__(self, mfcc, model_input_size, stride):
+ self.model_input_size = model_input_size
+ self.stride = stride
+ self._mfcc_calc = mfcc
+
+ def _normalize(self, values):
+ """
+ Normalize values to mean 0 and std 1
+ """
+ ret_val = (values - np.mean(values)) / np.std(values)
+ return ret_val
+
+ def _get_features(self, features, mfcc_instance, audio_data):
+ idx = 0
+ while len(features) < self.model_input_size * mfcc_instance.mfcc_params.num_mfcc_feats:
+ current_frame_feats = mfcc_instance.mfcc_compute(audio_data[idx:idx + int(mfcc_instance.mfcc_params.frame_len)])
+ features.extend(current_frame_feats)
+ idx += self.stride
+
+ def mfcc_delta_calc(self, features):
+ """
+ Placeholder function over-ridden in child class
+ """
+ return features
+
+ def extract_features(self, audio_data):
+ """
+ Extracts the MFCC features. Also calculates each features first and second order derivatives
+ if the mfcc_delta_calc() function has been implemented by a child class.
+ The matrix returned should be sized appropriately for input to the model, based
+ on the model info specified in the MFCC instance.
+
+ Args:
+ audio_data: the audio data to be used for this calculation
+ Returns:
+ the derived MFCC feature vector, sized appropriately for inference
+ """
+
+ num_samples_per_inference = ((self.model_input_size - 1)
+ * self.stride) + self._mfcc_calc.mfcc_params.frame_len
+
+ if len(audio_data) < num_samples_per_inference:
+ raise ValueError("audio_data size for feature extraction is smaller than "
+ "the expected number of samples needed for inference")
+
+ features = []
+ self._get_features(features, self._mfcc_calc, np.asarray(audio_data))
+ features = np.reshape(np.array(features), (self.model_input_size, self._mfcc_calc.mfcc_params.num_mfcc_feats))
+ features = self.mfcc_delta_calc(features)
+ return np.float32(features)
diff --git a/python/pyarmnn/examples/common/tests/conftest.py b/python/pyarmnn/examples/common/tests/conftest.py
deleted file mode 100644
index 5e027a0125..0000000000
--- a/python/pyarmnn/examples/common/tests/conftest.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-import ntpath
-
-import urllib.request
-import zipfile
-
-import pytest
-
-script_dir = os.path.dirname(__file__)
-@pytest.fixture(scope="session")
-def test_data_folder(request):
- """
- This fixture returns path to folder with shared test resources among all tests
- """
-
- data_dir = os.path.join(script_dir, "testdata")
- if not os.path.exists(data_dir):
- os.mkdir(data_dir)
-
- files_to_download = ["https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/messi5.jpg",
- "https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/basketball1.png",
- "https://raw.githubusercontent.com/opencv/opencv/4.0.0/samples/data/Megamind.avi",
- "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"
- ]
-
- for file in files_to_download:
- path, filename = ntpath.split(file)
- file_path = os.path.join(data_dir, filename)
- if not os.path.exists(file_path):
- print("\nDownloading test file: " + file_path + "\n")
- urllib.request.urlretrieve(file, file_path)
-
- # Any unzipping needed, and moving around of files
- with zipfile.ZipFile(os.path.join(data_dir, "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip"), 'r') as zip_ref:
- zip_ref.extractall(data_dir)
-
- return data_dir
diff --git a/python/pyarmnn/examples/common/tests/context.py b/python/pyarmnn/examples/common/tests/context.py
deleted file mode 100644
index 72246c03bf..0000000000
--- a/python/pyarmnn/examples/common/tests/context.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import os
-import sys
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
-
-import cv_utils
-import network_executor
-import utils
diff --git a/python/pyarmnn/examples/common/tests/test_network_executor.py b/python/pyarmnn/examples/common/tests/test_network_executor.py
deleted file mode 100644
index e27b382078..0000000000
--- a/python/pyarmnn/examples/common/tests/test_network_executor.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-
-import cv2
-
-from context import network_executor
-from context import cv_utils
-
-
-def test_execute_network(test_data_folder):
- model_path = os.path.join(test_data_folder, "detect.tflite")
- backends = ["CpuAcc", "CpuRef"]
-
- executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
- img = cv2.imread(os.path.join(test_data_folder, "messi5.jpg"))
- input_tensors = cv_utils.preprocess(img, executor.input_binding_info)
-
- output_result = executor.run(input_tensors)
-
- # Ensure it detects a person
- classes = output_result[1]
- assert classes[0][0] == 0
diff --git a/python/pyarmnn/examples/common/tests/test_utils.py b/python/pyarmnn/examples/common/tests/test_utils.py
deleted file mode 100644
index 28d68ea235..0000000000
--- a/python/pyarmnn/examples/common/tests/test_utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
-# SPDX-License-Identifier: MIT
-
-import os
-
-from context import cv_utils
-from context import utils
-
-
-def test_get_source_encoding(test_data_folder):
- video_file = os.path.join(test_data_folder, "Megamind.avi")
- video, video_writer, frame_count = cv_utils.init_video_file_capture(video_file, "/tmp")
- assert cv_utils.get_source_encoding_int(video) == 1145656920
-
-
-def test_read_existing_labels_file(test_data_folder):
- label_file = os.path.join(test_data_folder, "labelmap.txt")
- labels_map = utils.dict_labels(label_file)
- assert labels_map is not None
diff --git a/python/pyarmnn/examples/common/utils.py b/python/pyarmnn/examples/common/utils.py
index cf09fdefb8..d4dadf80a4 100644
--- a/python/pyarmnn/examples/common/utils.py
+++ b/python/pyarmnn/examples/common/utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Contains helper functions that can be used across the example apps."""
@@ -8,6 +8,7 @@ import errno
from pathlib import Path
import numpy as np
+import pyarmnn as ann
def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
@@ -39,3 +40,69 @@ def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
else:
labels[idx] = line.strip("\n")
return labels
+
+
+def prepare_input_tensors(audio_data, input_binding_info, mfcc_preprocessor):
+ """
+ Takes a block of audio data, extracts the MFCC features, quantizes the array, and uses ArmNN to create the
+ input tensors.
+
+ Args:
+ audio_data: The audio data to process
+ mfcc_instance: the mfcc class instance
+ input_binding_info: the model input binding info
+ mfcc_preprocessor: the mfcc preprocessor instance
+ Returns:
+ input_tensors: the prepared input tensors, ready to be consumed by the ArmNN NetworkExecutor
+ """
+
+ data_type = input_binding_info[1].GetDataType()
+ input_tensor = mfcc_preprocessor.extract_features(audio_data)
+ if data_type != ann.DataType_Float32:
+ input_tensor = quantize_input(input_tensor, input_binding_info)
+ input_tensors = ann.make_input_tensors([input_binding_info], [input_tensor])
+ return input_tensors
+
+
+def quantize_input(data, input_binding_info):
+ """Quantize the float input to (u)int8 ready for inputting to model."""
+ if data.ndim != 2:
+ raise RuntimeError("Audio data must have 2 dimensions for quantization")
+
+ quant_scale = input_binding_info[1].GetQuantizationScale()
+ quant_offset = input_binding_info[1].GetQuantizationOffset()
+ data_type = input_binding_info[1].GetDataType()
+
+ if data_type == ann.DataType_QAsymmS8:
+ data_type = np.int8
+ elif data_type == ann.DataType_QAsymmU8:
+ data_type = np.uint8
+ else:
+ raise ValueError("Could not quantize data to required data type")
+
+ d_min = np.iinfo(data_type).min
+ d_max = np.iinfo(data_type).max
+
+ for row in range(data.shape[0]):
+ for col in range(data.shape[1]):
+ data[row, col] = (data[row, col] / quant_scale) + quant_offset
+ data[row, col] = np.clip(data[row, col], d_min, d_max)
+ data = data.astype(data_type)
+ return data
+
+
+def dequantize_output(data, output_binding_info):
+ """Dequantize the (u)int8 output to float"""
+
+ if output_binding_info[1].IsQuantized():
+ if data.ndim != 2:
+ raise RuntimeError("Data must have 2 dimensions for quantization")
+
+ quant_scale = output_binding_info[1].GetQuantizationScale()
+ quant_offset = output_binding_info[1].GetQuantizationOffset()
+
+ data = data.astype(float)
+ for row in range(data.shape[0]):
+ for col in range(data.shape[1]):
+ data[row, col] = (data[row, col] - quant_offset)*quant_scale
+ return data