aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_graph_optimiser.py
diff options
context:
space:
mode:
authorAyaan Masood <Ayaan.Masood@arm.com>2022-06-29 18:16:04 +0100
committerAyaan Masood <Ayaan.Masood@arm.com>2022-06-29 18:16:04 +0100
commit25f48dd70aebeecd490de71eed3d4f7fbad1b121 (patch)
tree1cf03f59c8160a00a68faf0ffa62a9cd04a5c5b2 /ethosu/vela/test/test_graph_optimiser.py
parent4965faee41300393cd8d74da4b399fa4c4ee9030 (diff)
downloadethos-u-vela-25f48dd70aebeecd490de71eed3d4f7fbad1b121.tar.gz
MLBEDSW-6314 Static optimisation for quantise OP
*Quantise op becomes constant if input is known at compile time *Quantised values calculated if input of op is const and float *Const inputs to quant op that are int are requantized Change-Id: Ic94a72a392af709fe6a640d7dacbb5dc2334f16f Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_graph_optimiser.py')
-rw-r--r--ethosu/vela/test/test_graph_optimiser.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b8655c97..bfb1ded2 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -23,6 +23,7 @@ from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import optimise_graph
from ethosu.vela.nn_graph import NetworkType
from ethosu.vela.operation import Op
+from ethosu.vela.operation import Operation
from ethosu.vela.operation import Padding
from ethosu.vela.rewrite_graph import verify_graph_health
from ethosu.vela.tensor import create_const_tensor
@@ -31,6 +32,7 @@ from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding
from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape
+from ethosu.vela.tflite_graph_optimiser import optimise_quantize
from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad
from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input
@@ -533,3 +535,92 @@ def test_remove_expand_dims():
assert verify_graph_health(nng)
nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
assert verify_graph_health(nng)
+
+
+def test_quant_static_optimisations():
+
+ """
+ Tests if the quant value at vela compile time is calculated correctly
+ """
+
+ quant_ifm = create_const_tensor(
+ "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
+ )
+ quant_ifm.quantization = testutil.default_quant_params()
+ quant_ifm.quantization.scale_f32 = 0.15748031
+ quant_ifm.quantization.quant_min = -128
+ quant_ifm.quantization.quant_max = 127
+
+ quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
+ quant_ofm.quantization = testutil.default_quant_params()
+ quant_ofm.quantization.scale_f32 = 0.036092404
+ quant_ofm.quantization.zero_point = -128
+ quant_ofm.quantization.quant_min = -128
+ quant_ofm.quantization.quant_max = 127
+
+ # Create quant op
+
+ quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
+
+ quant_op.run_on_npu = True
+
+ op: Operation = optimise_quantize(quant_op, None, None)
+
+ assert op.ofm.values == 127
+
+ quant_ifm = create_const_tensor(
+ "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
+ )
+ quant_ifm.quantization = testutil.default_quant_params()
+ quant_ifm.quantization.scale_f32 = 0.15748031
+ quant_ifm.quantization.quant_min = -128
+ quant_ifm.quantization.quant_max = 127
+
+ quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
+ quant_ofm.quantization = testutil.default_quant_params()
+ quant_ofm.quantization.scale_f32 = 0.036092404
+ quant_ofm.quantization.zero_point = -128
+ quant_ofm.quantization.quant_min = -128
+ quant_ofm.quantization.quant_max = 127
+
+ # Create quant op
+
+ quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
+
+ quant_op.run_on_npu = True
+
+ op: Operation = optimise_quantize(quant_op, None, None)
+
+ assert op.ofm.values == 127
+
+
+def test_optimise_quantize_multiple_values():
+ """
+ Tests if the quant value at vela compile time is calculated correctly
+ when passing multiple values to quantize node
+ """
+
+ quant_ifm = create_const_tensor(
+ "const_quant_ifm", values=np.array([127, 127]), value_dtype=np.int8, shape=[], dtype=DataType.int8
+ )
+ quant_ifm.quantization = testutil.default_quant_params()
+ quant_ifm.quantization.scale_f32 = 0.15748031
+ quant_ifm.quantization.quant_min = -128
+ quant_ifm.quantization.quant_max = 127
+
+ quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
+ quant_ofm.quantization = testutil.default_quant_params()
+ quant_ofm.quantization.scale_f32 = 0.036092404
+ quant_ofm.quantization.zero_point = -128
+ quant_ofm.quantization.quant_min = -128
+ quant_ofm.quantization.quant_max = 127
+
+ # Create quant op
+
+ quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
+
+ quant_op.run_on_npu = True
+
+ op: Operation = optimise_quantize(quant_op, None, None)
+
+ assert (op.ofm.values == np.array([127, 127])).all()