aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Fredrik Svedberg <fredrik.svedberg@arm.com> 2022-07-06 13:42:24 +0200
committer: Fredrik Svedberg <fredrik.svedberg@arm.com> 2022-07-13 15:18:20 +0000
commit: a04f2f7322e7b83d93e875313d2e5b4d0dca94fb (patch)
tree: c630418ccf8ca22b6ffae2ca9865967d7c7ac0f1
parent: c4d35eb580902dfe6acedb2db3a72c32760f86af (diff)
download: ethos-u-vela-a04f2f7322e7b83d93e875313d2e5b4d0dca94fb.tar.gz
MLBEDSW-6687 Vela crashes in npu_serialisation.py and tflite_graph_optimiser.py
Fixed static optimisation of the Quantize operator by running unsupported formats on the CPU. Also added support for int16 and corrected the calculation.

Change-Id: I861c712aa6258dba53fcf4d5dae45d1d416e6141
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py 56
1 file changed, 25 insertions, 31 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 10ddca60..f2a8c803 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -25,6 +25,7 @@ from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
+from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
@@ -1408,40 +1409,44 @@ def optimise_quantize(op: Operation, arch, nng):
input_values = np.array([input_values])
# requantized int8 to int8
- if ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8:
+ if (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) or (
+ ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16
+ ):
# scale needs to use double precision to match TFLite reference kernel
effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
effective_multiplier, effective_shift = quantise_scale(effective_scale)
- assert effective_shift >= 0
- assert -31 <= effective_shift <= 30
- round_val = 1 << (effective_shift - 1)
-
requantized_vals = []
- for val in input_values:
+ for val in input_values.flatten():
input_val = val - ifm.quantization.zero_point
- output = input_val * effective_multiplier + round_val
- ofm_val = (output >> effective_shift) + ofm.quantization.zero_point
+ ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
+ ofm_val += ofm.quantization.zero_point
- clamped_ofm_values = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
- requantized_vals.append(clamped_ofm_values)
+ clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
+ requantized_vals.append(clamped_ofm_value)
- ofm.values = np.array(requantized_vals)
+ ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
+ ofm.values.shape = input_values.shape
# Case: Float input - quantize to int
- elif np.issubdtype(input_values.dtype, np.float):
+ elif ifm.dtype.type == BaseType.Float:
quantized_vals = []
for val in input_values:
# Derive quantized value
quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
- quantized_vals.append(quant_val)
+ clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
+ quantized_vals.append(clamped_quantized_val)
# Pass the statically calculated quant val to output tensor
- ofm.values = np.array(quantized_vals)
+ ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
+
+ # Unsupported data type
+ else:
+ return op
# Make quantize op const and disconnect from parent node
@@ -1493,23 +1498,6 @@ def tflite_optimise_graph(nng, arch):
# Compile time optimisations
optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
- for optimisation in optimisation_list:
- for idx, sg in enumerate(nng.subgraphs):
- nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng,
- sg,
- arch,
- [],
- [optimisation],
- rewrite_unsupported=False,
- )
-
- # Pre-processing step
- pre_process_list = [
- supported_operator_check,
- set_ifm_ofm_op_shapes,
- ]
-
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
nng,
@@ -1520,6 +1508,12 @@ def tflite_optimise_graph(nng, arch):
rewrite_unsupported=False,
)
+ # Pre-processing step
+ pre_process_list = [
+ supported_operator_check,
+ set_ifm_ofm_op_shapes,
+ ]
+
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
nng,