author     Ayaan Masood <Ayaan.Masood@arm.com>    2022-06-29 18:16:04 +0100
committer  Ayaan Masood <Ayaan.Masood@arm.com>    2022-06-29 18:16:04 +0100
commit     25f48dd70aebeecd490de71eed3d4f7fbad1b121 (patch)
tree       1cf03f59c8160a00a68faf0ffa62a9cd04a5c5b2
parent     4965faee41300393cd8d74da4b399fa4c4ee9030 (diff)
download   ethos-u-vela-25f48dd70aebeecd490de71eed3d4f7fbad1b121.tar.gz
MLBEDSW-6314 Static optimisation for quantise OP
* Quantise op becomes constant if input is known at compile time
* Quantised values calculated if input of op is const and float
* Const inputs to quant op that are int are requantized

Change-Id: Ic94a72a392af709fe6a640d7dacbb5dc2334f16f
Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
-rw-r--r--  SUPPORTED_OPS.md                           |  6
-rw-r--r--  ethosu/vela/test/test_graph_optimiser.py   | 91
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py      | 80
-rw-r--r--  ethosu/vela/tflite_model_semantic.py       |  7
4 files changed, 178 insertions(+), 6 deletions(-)
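
For reference, the static optimisation below applies standard affine quantisation at compile time: a real value v maps to round(v / scale) + zero_point, clamped to the quantised range. A minimal standalone sketch of that arithmetic (illustrative names only, not vela's API; constants taken from the tests below):

    import numpy as np

    def quantise_to_int8(values, scale, zero_point, qmin=-128, qmax=127):
        # Affine quantisation: scale, offset by the zero point, round, then clamp
        q = np.round(np.asarray(values, dtype=np.float64) / scale) + zero_point
        return np.clip(q, qmin, qmax).astype(np.int8)

    # 20.0 / 0.036092404 - 128 ~= 426, which clamps to the int8 maximum of 127
    print(quantise_to_int8([20.0], scale=0.036092404, zero_point=-128))  # [127]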
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index c258dfbd..83429b7a 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -58,10 +58,10 @@ Please check the supported operator list for your chosen runtime for further inf
 This is a list of constraints most NPU operators must satisfy in order to be scheduled on the NPU.
 (Operators excluded from certain constraints are shown in brackets [ ] )
-- Input(s) and Output tensors must not be dynamic
+- Input(s) and Output tensors must not be dynamic - [Quantize]
 - Input(s) and Output tensors must have a defined shape
-- Output tensors cannot be scalar
-- Scalar Input tensors are only valid for op type: ADD, EXPAND_DIMS, MAXIMUM, MEAN, MINIMUM, MUL, SPLIT, SPLIT_V, SUB
+- Output tensors cannot be scalar - [Quantize]
+- Scalar Input tensors are only valid for op type: ADD, EXPAND_DIMS, MAXIMUM, MEAN, MINIMUM, MUL, SPLIT, SPLIT_V, SUB - [Quantize]
 - Input(s) and Output tensors must not be greater than 4D
 - Input(s), Output and Weight tensors must have quantization parameters - [Shape]
 - Input(s), Output and Weight tensors with quantization scales must be finite - [Shape]
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b8655c97..bfb1ded2 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -23,6 +23,7 @@ from ethosu.vela.data_type import DataType
 from ethosu.vela.graph_optimiser import optimise_graph
 from ethosu.vela.nn_graph import NetworkType
 from ethosu.vela.operation import Op
+from ethosu.vela.operation import Operation
 from ethosu.vela.operation import Padding
 from ethosu.vela.rewrite_graph import verify_graph_health
 from ethosu.vela.tensor import create_const_tensor
@@ -31,6 +32,7 @@ from ethosu.vela.tensor import Tensor
 from ethosu.vela.test import testutil
 from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding
 from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape
+from ethosu.vela.tflite_graph_optimiser import optimise_quantize
 from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad
 from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input
@@ -533,3 +535,92 @@ def test_remove_expand_dims():
     assert verify_graph_health(nng)
     nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
     assert verify_graph_health(nng)
+
+
+def test_quant_static_optimisations():
+
+    """
+    Tests if the quant value at vela compile time is calculated correctly
+    """
+
+    quant_ifm = create_const_tensor(
+        "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
+    )
+    quant_ifm.quantization = testutil.default_quant_params()
+    quant_ifm.quantization.scale_f32 = 0.15748031
+    quant_ifm.quantization.quant_min = -128
+    quant_ifm.quantization.quant_max = 127
+
+    quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
+    quant_ofm.quantization = testutil.default_quant_params()
+    quant_ofm.quantization.scale_f32 = 0.036092404
+    quant_ofm.quantization.zero_point = -128
+    quant_ofm.quantization.quant_min = -128
+    quant_ofm.quantization.quant_max = 127
+
+    # Create quant op
+
+    quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
+
+    quant_op.run_on_npu = True
+
+    op: Operation = optimise_quantize(quant_op, None, None)
+
+    assert op.ofm.values == 127
+
+    quant_ifm = create_const_tensor(
+        "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
+    )
+    quant_ifm.quantization = testutil.default_quant_params()
+    quant_ifm.quantization.scale_f32 = 0.15748031
+    quant_ifm.quantization.quant_min = -128
+    quant_ifm.quantization.quant_max = 127
+
+    quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
+    quant_ofm.quantization = testutil.default_quant_params()
+    quant_ofm.quantization.scale_f32 = 0.036092404
+    quant_ofm.quantization.zero_point = -128
+    quant_ofm.quantization.quant_min = -128
+    quant_ofm.quantization.quant_max = 127
+
+    # Create quant op
+
+    quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
+
+    quant_op.run_on_npu = True
+
+    op: Operation = optimise_quantize(quant_op, None, None)
+
+    assert op.ofm.values == 127
+
+
+def test_optimise_quantize_multiple_values():
+    """
+    Tests if the quant value at vela compile time is calculated correctly
+    when passing multiple values to quantize node
+    """
+
+    quant_ifm = create_const_tensor(
+        "const_quant_ifm", values=np.array([127, 127]), value_dtype=np.int8, shape=[], dtype=DataType.int8
+    )
+    quant_ifm.quantization = testutil.default_quant_params()
+    quant_ifm.quantization.scale_f32 = 0.15748031
+    quant_ifm.quantization.quant_min = -128
+    quant_ifm.quantization.quant_max = 127
+
+    quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
+    quant_ofm.quantization = testutil.default_quant_params()
+    quant_ofm.quantization.scale_f32 = 0.036092404
+    quant_ofm.quantization.zero_point = -128
+    quant_ofm.quantization.quant_min = -128
+    quant_ofm.quantization.quant_max = 127
+
+    # Create quant op
+
+    quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
+
+    quant_op.run_on_npu = True
+
+    op: Operation = optimise_quantize(quant_op, None, None)
+
+    assert (op.ofm.values == np.array([127, 127])).all()
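
A worked check of the constants used in these tests, assuming the ifm zero point left by testutil.default_quant_params() is 0 (the tests do not override it):

    ifm_scale, ofm_scale, ofm_zp = 0.15748031, 0.036092404, -128
    real_value = 127 * ifm_scale                       # dequantise the int8 input: ~20.0
    requantised = real_value / ofm_scale + ofm_zp      # ~554.1 - 128 = ~426.1
    clamped = max(min(round(requantised), 127), -128)  # clamp to the int8 range
    assert clamped == 127                              # hence every assert above expects 127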
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index cf3985e4..10ddca60 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -49,6 +49,7 @@ from .operation import Operation
 from .operation import Padding
 from .operation_util import create_avgpool_nop
 from .operation_util import get_pad_values_from_input
+from .scaling import quantise_scale
 from .shape4d import Shape4D
 from .softmax import SoftMax
 from .tensor import check_quantized_tens_scaling_equal
@@ -1391,6 +1392,71 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
     return op
+
+
+def optimise_quantize(op: Operation, arch, nng):
+
+    if op.type == Op.Quantize and op.run_on_npu:
+
+        ifm, ofm = op.get_ifm_ofm()
+        input_values = ifm.values
+
+        # Guard clause - input not const or no values to quantize
+        if ifm.ops[0].type != Op.Const or input_values is None:
+            return op
+
+        # 0-D (scalar) numpy array - convert to 1-D so the values can be iterated
+        if input_values.ndim == 0:
+            input_values = np.array([input_values])
+
+        # Requantise int8 to int8
+        if ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8:
+
+            # Scale needs to use double precision to match the TFLite reference kernel
+            effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
+            effective_multiplier, effective_shift = quantise_scale(effective_scale)
+
+            # Shift must be in [1, 30]: a zero shift would make the rounding term below invalid
+            assert 0 < effective_shift <= 30
+            round_val = 1 << (effective_shift - 1)
+
+            requantized_vals = []
+            for val in input_values:
+                input_val = val - ifm.quantization.zero_point
+
+                output = input_val * effective_multiplier + round_val
+                ofm_val = (output >> effective_shift) + ofm.quantization.zero_point
+
+                clamped_ofm_val = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
+                requantized_vals.append(clamped_ofm_val)
+
+            ofm.values = np.array(requantized_vals)
+
+        # Case: Float input - quantize to int
+        elif np.issubdtype(input_values.dtype, np.floating):
+
+            quantized_vals = []
+            for val in input_values:
+
+                # Derive quantized value
+                quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
+                quantized_vals.append(quant_val)
+
+            # Pass the statically calculated quant val to output tensor
+            ofm.values = np.array(quantized_vals)
+
+        # Make quantize op const and disconnect from parent node
+
+        # Remove reference of the current quant op from the parent tensor's consumer list
+        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
+
+        # Clear any references to parent node
+        op.inputs = []
+
+        # Convert this quantize op to const
+        op.type = Op.Const
+
+    return op
+
+
 def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
     """Static optimisation for SHAPE operator output value known at compile time"""
@@ -1424,9 +1490,19 @@ def supported_operator_check(op, arch, nng):
 def tflite_optimise_graph(nng, arch):
-
     # Compile time optimisations
-    optimisation_list = [convert_shape_op_to_constant_tensor]
+    optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
+
+    for optimisation in optimisation_list:
+        for idx, sg in enumerate(nng.subgraphs):
+            nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+                nng,
+                sg,
+                arch,
+                [],
+                [optimisation],
+                rewrite_unsupported=False,
+            )
 
     # Pre-processing step
     pre_process_list = [
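
The int8-to-int8 path above leans on quantise_scale() to express the effective scale as a fixed-point (multiplier, shift) pair; assuming vela's Q31 convention, scale ~= multiplier / 2**shift (an assumption based on the scaling module, which this diff does not show), the per-value arithmetic reduces to a multiply, a rounding add, and an arithmetic shift. A standalone sketch:

    import math

    def requantise(val, ifm_zp, ofm_zp, ifm_scale, ofm_scale, qmin=-128, qmax=127):
        effective_scale = ifm_scale / ofm_scale
        significand, exponent = math.frexp(effective_scale)  # effective_scale == significand * 2**exponent
        multiplier = int(round(significand * (1 << 31)))     # Q31 fixed-point significand (assumed convention)
        shift = 31 - exponent
        # Multiply up, add half a step for round-to-nearest, shift back down
        rounded = ((val - ifm_zp) * multiplier + (1 << (shift - 1))) >> shift
        return max(min(rounded + ofm_zp, qmax), qmin)

    print(requantise(127, 0, -128, 0.15748031, 0.036092404))  # 127, matching the tests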
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index ee66d4cc..9408e0ce 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -216,7 +216,12 @@ class TFLiteSemantic:
                 TFLiteSemantic.constraint_tens_quant_none_check,
                 TFLiteSemantic.constraint_tens_quant_scale,
                 TFLiteSemantic.constraint_quant_scale_inf,
-            ]
+            ],
+            Op.Quantize: [
+                TFLiteSemantic.constraint_tens_no_dynamic,
+                TFLiteSemantic.constraint_tens_output_scalar,
+                TFLiteSemantic.constraint_tens_input_scalar,
+            ],
         }
         return generic_constraints_exclude_list
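
For context, this dictionary lists the generic semantic checks each op type is excused from. A hypothetical sketch of how such an exclude list is consulted (function and variable names are assumptions for illustration, not vela code):

    # Hypothetical sketch, not vela's actual checker: constraints excused for an
    # op type are skipped when validating that op.
    def run_generic_checks(op_type, op, constraints, exclude_list):
        excused = exclude_list.get(op_type, [])
        # e.g. Op.Quantize now skips the dynamic-tensor and scalar checks above
        return [check(op) for check in constraints if check not in excused]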