From a04f2f7322e7b83d93e875313d2e5b4d0dca94fb Mon Sep 17 00:00:00 2001 From: Fredrik Svedberg Date: Wed, 6 Jul 2022 13:42:24 +0200 Subject: MLBEDSW-6687 Vela crashes in npu_serialisation.py and tflite_graph_optimiser.py Fixed static optimisation of Quantize operator by running unsupported formats on CPU. Also added support for int16 and corrected the calculation. Change-Id: I861c712aa6258dba53fcf4d5dae45d1d416e6141 Signed-off-by: Fredrik Svedberg --- ethosu/vela/tflite_graph_optimiser.py | 56 ++++++++++++++++------------------- 1 file changed, 25 insertions(+), 31 deletions(-) (limited to 'ethosu/vela') diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 10ddca60..f2a8c803 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -25,6 +25,7 @@ from . import fp_math from . import rewrite_graph from . import scaling from .api import NpuRoundingMode +from .data_type import BaseType from .data_type import DataType from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError @@ -1408,40 +1409,44 @@ def optimise_quantize(op: Operation, arch, nng): input_values = np.array([input_values]) # requantized int8 to int8 - if ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8: + if (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) or ( + ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16 + ): # scale needs to use double precision to match TFLite reference kernel effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32) effective_multiplier, effective_shift = quantise_scale(effective_scale) - assert effective_shift >= 0 - assert -31 <= effective_shift <= 30 - round_val = 1 << (effective_shift - 1) - requantized_vals = [] - for val in input_values: + for val in input_values.flatten(): input_val = val - ifm.quantization.zero_point - output = input_val * effective_multiplier + round_val - ofm_val = (output >> effective_shift) + ofm.quantization.zero_point + ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift) + ofm_val += ofm.quantization.zero_point - clamped_ofm_values = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min) - requantized_vals.append(clamped_ofm_values) + clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min) + requantized_vals.append(clamped_ofm_value) - ofm.values = np.array(requantized_vals) + ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type()) + ofm.values.shape = input_values.shape # Case: Float input - quantize to int - elif np.issubdtype(input_values.dtype, np.float): + elif ifm.dtype.type == BaseType.Float: quantized_vals = [] for val in input_values: # Derive quantized value quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point - quantized_vals.append(quant_val) + clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max) + quantized_vals.append(clamped_quantized_val) # Pass the statically calculated quant val to output tensor - ofm.values = np.array(quantized_vals) + ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type()) + + # Unsupported data type + else: + return op # Make quantize op const and disconnect from parent node @@ -1493,23 +1498,6 @@ def tflite_optimise_graph(nng, arch): # Compile time optimisations optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor] - for optimisation in optimisation_list: - for idx, sg in enumerate(nng.subgraphs): - nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, - sg, - arch, - [], - [optimisation], - rewrite_unsupported=False, - ) - - # Pre-processing step - pre_process_list = [ - supported_operator_check, - set_ifm_ofm_op_shapes, - ] - for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( nng, @@ -1520,6 +1508,12 @@ def tflite_optimise_graph(nng, arch): rewrite_unsupported=False, ) + # Pre-processing step + pre_process_list = [ + supported_operator_check, + set_ifm_ofm_op_shapes, + ] + for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( nng, -- cgit v1.2.1