diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2022-07-06 13:42:24 +0200 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2022-07-13 15:18:20 +0000 |
commit | a04f2f7322e7b83d93e875313d2e5b4d0dca94fb (patch) | |
tree | c630418ccf8ca22b6ffae2ca9865967d7c7ac0f1 /ethosu/vela/tflite_graph_optimiser.py | |
parent | c4d35eb580902dfe6acedb2db3a72c32760f86af (diff) | |
download | ethos-u-vela-a04f2f7322e7b83d93e875313d2e5b4d0dca94fb.tar.gz |
MLBEDSW-6687 Vela crashes in npu_serialisation.py and tflite_graph_optimiser.py
Fixed static optimisation of Quantize operator by running unsupported
formats on CPU. Also added support for int16 and corrected the
calculation.
Change-Id: I861c712aa6258dba53fcf4d5dae45d1d416e6141
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 56 |
1 files changed, 25 insertions, 31 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 10ddca60..f2a8c803 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -25,6 +25,7 @@ from . import fp_math from . import rewrite_graph from . import scaling from .api import NpuRoundingMode +from .data_type import BaseType from .data_type import DataType from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError @@ -1408,40 +1409,44 @@ def optimise_quantize(op: Operation, arch, nng): input_values = np.array([input_values]) # requantized int8 to int8 - if ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8: + if (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) or ( + ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16 + ): # scale needs to use double precision to match TFLite reference kernel effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32) effective_multiplier, effective_shift = quantise_scale(effective_scale) - assert effective_shift >= 0 - assert -31 <= effective_shift <= 30 - round_val = 1 << (effective_shift - 1) - requantized_vals = [] - for val in input_values: + for val in input_values.flatten(): input_val = val - ifm.quantization.zero_point - output = input_val * effective_multiplier + round_val - ofm_val = (output >> effective_shift) + ofm.quantization.zero_point + ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift) + ofm_val += ofm.quantization.zero_point - clamped_ofm_values = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min) - requantized_vals.append(clamped_ofm_values) + clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min) + requantized_vals.append(clamped_ofm_value) - ofm.values = np.array(requantized_vals) + ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type()) + ofm.values.shape = input_values.shape # Case: Float input - quantize to int - elif np.issubdtype(input_values.dtype, np.float): + elif ifm.dtype.type == BaseType.Float: quantized_vals = [] for val in input_values: # Derive quantized value quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point - quantized_vals.append(quant_val) + clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max) + quantized_vals.append(clamped_quantized_val) # Pass the statically calculated quant val to output tensor - ofm.values = np.array(quantized_vals) + ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type()) + + # Unsupported data type + else: + return op # Make quantize op const and disconnect from parent node @@ -1493,23 +1498,6 @@ def tflite_optimise_graph(nng, arch): # Compile time optimisations optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor] - for optimisation in optimisation_list: - for idx, sg in enumerate(nng.subgraphs): - nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, - sg, - arch, - [], - [optimisation], - rewrite_unsupported=False, - ) - - # Pre-processing step - pre_process_list = [ - supported_operator_check, - set_ifm_ofm_op_shapes, - ] - for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( nng, @@ -1520,6 +1508,12 @@ def tflite_optimise_graph(nng, arch): rewrite_unsupported=False, ) + # Pre-processing step + pre_process_list = [ + supported_operator_check, + set_ifm_ofm_op_shapes, + ] + for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( nng, |