diff options
author | Ayaan Masood <Ayaan.Masood@arm.com> | 2022-06-29 18:16:04 +0100 |
---|---|---|
committer | Ayaan Masood <Ayaan.Masood@arm.com> | 2022-06-29 18:16:04 +0100 |
commit | 25f48dd70aebeecd490de71eed3d4f7fbad1b121 (patch) | |
tree | 1cf03f59c8160a00a68faf0ffa62a9cd04a5c5b2 /ethosu/vela/test | |
parent | 4965faee41300393cd8d74da4b399fa4c4ee9030 (diff) | |
download | ethos-u-vela-25f48dd70aebeecd490de71eed3d4f7fbad1b121.tar.gz |
MLBEDSW-6314 Static optimisation for quantise OP
*Quantise op becomes constant if input is known at compile time
*Quantised values calculated if input of op is const and float
*Const inputs to quant op that are int are requantized
Change-Id: Ic94a72a392af709fe6a640d7dacbb5dc2334f16f
Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
Diffstat (limited to 'ethosu/vela/test')
-rw-r--r-- | ethosu/vela/test/test_graph_optimiser.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py index b8655c97..bfb1ded2 100644 --- a/ethosu/vela/test/test_graph_optimiser.py +++ b/ethosu/vela/test/test_graph_optimiser.py @@ -23,6 +23,7 @@ from ethosu.vela.data_type import DataType from ethosu.vela.graph_optimiser import optimise_graph from ethosu.vela.nn_graph import NetworkType from ethosu.vela.operation import Op +from ethosu.vela.operation import Operation from ethosu.vela.operation import Padding from ethosu.vela.rewrite_graph import verify_graph_health from ethosu.vela.tensor import create_const_tensor @@ -31,6 +32,7 @@ from ethosu.vela.tensor import Tensor from ethosu.vela.test import testutil from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape +from ethosu.vela.tflite_graph_optimiser import optimise_quantize from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input @@ -533,3 +535,92 @@ def test_remove_expand_dims(): assert verify_graph_health(nng) nng = optimise_graph(nng, arch, NetworkType.TFLite, True) assert verify_graph_health(nng) + + +def test_quant_static_optimisations(): + + """ + Tests if the quant value at vela compile time is calculated correctly + """ + + quant_ifm = create_const_tensor( + "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8 + ) + quant_ifm.quantization = testutil.default_quant_params() + quant_ifm.quantization.scale_f32 = 0.15748031 + quant_ifm.quantization.quant_min = -128 + quant_ifm.quantization.quant_max = 127 + + quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8) + quant_ofm.quantization = testutil.default_quant_params() + quant_ofm.quantization.scale_f32 = 0.036092404 + quant_ofm.quantization.zero_point = -128 + quant_ofm.quantization.quant_min = -128 + quant_ofm.quantization.quant_max = 127 + + # Create quant op + + quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm) + + quant_op.run_on_npu = True + + op: Operation = optimise_quantize(quant_op, None, None) + + assert op.ofm.values == 127 + + quant_ifm = create_const_tensor( + "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8 + ) + quant_ifm.quantization = testutil.default_quant_params() + quant_ifm.quantization.scale_f32 = 0.15748031 + quant_ifm.quantization.quant_min = -128 + quant_ifm.quantization.quant_max = 127 + + quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8) + quant_ofm.quantization = testutil.default_quant_params() + quant_ofm.quantization.scale_f32 = 0.036092404 + quant_ofm.quantization.zero_point = -128 + quant_ofm.quantization.quant_min = -128 + quant_ofm.quantization.quant_max = 127 + + # Create quant op + + quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm) + + quant_op.run_on_npu = True + + op: Operation = optimise_quantize(quant_op, None, None) + + assert op.ofm.values == 127 + + +def test_optimise_quantize_multiple_values(): + """ + Tests if the quant value at vela compile time is calculated correctly + when passing multiple values to quantize node + """ + + quant_ifm = create_const_tensor( + "const_quant_ifm", values=np.array([127, 127]), value_dtype=np.int8, shape=[], dtype=DataType.int8 + ) + quant_ifm.quantization = testutil.default_quant_params() + quant_ifm.quantization.scale_f32 = 0.15748031 + quant_ifm.quantization.quant_min = -128 + quant_ifm.quantization.quant_max = 127 + + quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8) + quant_ofm.quantization = testutil.default_quant_params() + quant_ofm.quantization.scale_f32 = 0.036092404 + quant_ofm.quantization.zero_point = -128 + quant_ofm.quantization.quant_min = -128 + quant_ofm.quantization.quant_max = 127 + + # Create quant op + + quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm) + + quant_op.run_on_npu = True + + op: Operation = optimise_quantize(quant_op, None, None) + + assert (op.ofm.values == np.array([127, 127])).all() |