author     Louis Verhaard <louis.verhaard@arm.com>    2020-08-25 13:36:41 +0200
committer  tim.hall <tim.hall@arm.com>                2020-08-28 16:48:54 +0000
commit     d7911c44323f2704157cfde6e413136b070f5d4b (patch)
tree       9983d34f204a17a6e4d094909f2222e9de828997
parent     7201932246734b8b5db016106ad8df108d2513d0 (diff)
download   ethos-u-vela-d7911c44323f2704157cfde6e413136b070f5d4b.tar.gz
MLBEDSW-2688: LUT calculation with different in/out scale
Enables LUT for LeakyRelu with int8/uint8 even if input scale is different from the output scale. Fusing LUT with a previous operator for this situation requires further work.

Change-Id: I9eddfe36f457e763d44eb3e05fbe240eac7cfec9
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--  ethosu/vela/fp_math.py                             10
-rw-r--r--  ethosu/vela/graph_optimiser.py                     63
-rw-r--r--  ethosu/vela/register_command_stream_generator.py   33
-rw-r--r--  ethosu/vela/test/test_fp_math.py                   37
4 files changed, 115 insertions, 28 deletions
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 2055879a..eaeb84a1 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -136,3 +136,13 @@ def exp_on_negative_values(a):
return np.iinfo(np.int32).max
else:
return result
+
+
+def multiply_by_quantized_multiplier(x, scale, shift):
+ # Multiplies x (int32) by (scale, shift), which have been obtained by a call to scaling.quantise_scale,
+ # and returns the rounded result
+ shift = 31 - shift
+ left_shift = shift if shift > 0 else 0
+ right_shift = -shift if shift < 0 else 0
+ mul = saturating_rounding_mul(x * (1 << left_shift), scale)
+ return rounding_divide_by_pot(mul, right_shift)
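
For readers unfamiliar with the fixed-point helpers, the new function chains a saturating rounding multiply with a rounding divide by a power of two. The sketch below is a minimal plain-Python reference of that arithmetic, assuming scaling.quantise_scale returns a Q31 significand plus a shift such that the real factor is approximately scale * 2**-shift; the _ref helpers are illustrative and their tie-breaking and saturation details may differ from the actual fp_math primitives.

```python
# Minimal reference sketch (not vela code) of the arithmetic behind
# fp_math.multiply_by_quantized_multiplier; rounding is to nearest.

INT32_MIN, INT32_MAX = -(1 << 31), (1 << 31) - 1


def _saturating_rounding_mul_ref(a: int, b: int) -> int:
    # Rounded (a * b) / 2**31, saturating the one overflow case.
    if a == b == INT32_MIN:
        return INT32_MAX
    ab = a * b
    sign = 1 if ab >= 0 else -1
    return sign * ((abs(ab) + (1 << 30)) >> 31)


def _rounding_divide_by_pot_ref(x: int, exponent: int) -> int:
    # Rounded x / 2**exponent.
    if exponent == 0:
        return x
    sign = 1 if x >= 0 else -1
    return sign * ((abs(x) + (1 << (exponent - 1))) >> exponent)


def multiply_by_quantized_multiplier_ref(x: int, scale: int, shift: int) -> int:
    # Same shift handling as the committed helper: undo the 31-bit bias,
    # pre-shift left for positive exponents, post-shift right (rounded) otherwise.
    exponent = 31 - shift
    left_shift = exponent if exponent > 0 else 0
    right_shift = -exponent if exponent < 0 else 0
    mul = _saturating_rounding_mul_ref(x * (1 << left_shift), scale)
    return _rounding_divide_by_pot_ref(mul, right_shift)


if __name__ == "__main__":
    import math

    # Quantise the real factor 0.1 under the assumed convention:
    # 0.1 = 0.8 * 2**-3, Q31 significand = round(0.8 * 2**31), shift = 31 - (-3).
    significand, exponent = math.frexp(0.1)
    scale, shift = int(round(significand * (1 << 31))), 31 - exponent
    assert multiply_by_quantized_multiplier_ref(1000, scale, shift) == 100
```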
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index aaccce2c..7ab009f0 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -20,8 +20,10 @@ import math
import numpy as np
+from . import fp_math
from . import lut
from . import rewrite_graph
+from . import scaling
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
@@ -637,7 +639,8 @@ def convert_mul_max_to_abs_or_lrelu(op, arch):
return op
# make sure the Mul doesn't have any other consumers
- if len(mul.outputs[0].consumers()) != 1:
+ mul_ofm = mul.outputs[0]
+ if len(mul_ofm.consumers()) != 1:
return op
# make sure the Mul doesn't have a faf
if mul.attrs["fused_activation_function"]:
@@ -645,7 +648,7 @@ def convert_mul_max_to_abs_or_lrelu(op, arch):
ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
return op
- if not ifm.is_scaling_equal(ofm):
+ if not ifm.is_scaling_equal(ofm) or not ifm.is_scaling_equal(mul_ofm):
# rewrite to LeakyRelu currently only makes sense if the quantization is identical
return op
@@ -671,6 +674,15 @@ def convert_mul_max_to_abs_or_lrelu(op, arch):
if val >= 0:
new_op = "LeakyRelu"
op.attrs["alpha"] = val
+ # to produce bit-exact results, the alpha attribute alone is not enough;
+ # save additional scaling info in attr "alpha_scaling", to be used as input
+ # to the LUT construction
+ alpha_scalar = const_tens.quant_values - const_tens.quantization.zero_point
+ mul_ifm_scale = np.double(ifm.quantization.scale_f32)
+ mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
+ mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
+ alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
+ op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
elif val == -1:
new_op = "Abs"
else:
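
A hedged sketch of what the stored attribute encodes: elementwise_mul_scale is assumed here to quantise the combined real multiplier (ifm_scale * ifm2_scale) / ofm_scale into an integer (scale, shift) pair, so the LUT construction can later reproduce the quantised Mul bit-exactly. The helper and numbers below are illustrative only, not vela's implementation.

```python
# Illustrative only: assumed behaviour of scaling.elementwise_mul_scale.
import math


def elementwise_mul_scale_sketch(ifm_scale, ifm2_scale, ofm_scale):
    # Quantise the combined real multiplier into a Q31 significand and shift.
    real_multiplier = (ifm_scale * ifm2_scale) / ofm_scale
    significand, exponent = math.frexp(real_multiplier)
    return int(round(significand * (1 << 31))), 31 - exponent


# Hypothetical quantisation parameters for Mul(IFM, alpha_const):
alpha_scalar = 26  # e.g. const_tens.quant_values - zero_point for the alpha constant
alpha_scale, alpha_shift = elementwise_mul_scale_sketch(0.05, 0.004, 0.05)
```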
@@ -744,15 +756,39 @@ def convert_lrelu_to_lut(op, arch):
op.attrs["is_nop"] = True
# Create an input tensor containing scalar zero
quantization = QuantizationParameters(0.0, 255.0)
- quantization.scale_f32 = 1.0
+ quantization.scale_f32 = ifm.quantization.scale_f32
quantization.zero_point = 0
tens = create_const_tensor(op.inputs[0].name + "_add", [], ifm.dtype, [0], np.uint8, quantization=quantization)
op.add_input_tensor(tens)
- alpha = op.attrs["alpha"]
- zp = ofm.quantization.zero_point
# Generate the LUT
+ alpha = op.attrs["alpha"]
+ ifm_scale = np.double(ifm.quantization.scale_f32)
+ ofm_scale = np.double(ofm.quantization.scale_f32)
+ zp_in = ifm.quantization.zero_point
+ zp_out = ofm.quantization.zero_point
+ identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
+ alpha_scalar = 1
+ alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
+ if "alpha_scaling" in op.attrs:
+ # The LeakyRelu is the result of convert_mul_max_to_abs_or_lrelu
+ alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
+ values = []
ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
- values = [int(x) if x >= zp else int(round(zp - alpha * (zp - x))) for x in ix]
+ quantized_min = min(ix)
+ quantized_max = max(ix)
+ for x in ix:
+ if x < zp_in:
+ lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
+ alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
+ )
+ else:
+ lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
+ lut_result = min(quantized_max, max(quantized_min, lut_result))
+ values.append(lut_result)
+ # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
+ # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
+ # should use the same scale as the IFM
+ op.attrs["forced_output_quantization"] = ifm.quantization
lut_tensor = lut.create_lut_tensor(op.name + "_lut", values, DataType.int8)
op.set_activation_lut(lut_tensor)
return op
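
To make the table construction concrete, the float reference below computes what each LUT entry approximates: dequantise the input, apply LeakyRelu in the real domain, requantise to the OFM parameters and clamp. The integer path above is the authoritative bit-exact construction; this sketch uses illustrative numbers and Python's default rounding, which can differ at exact ties.

```python
# Float reference for a single LeakyRelu LUT entry (illustrative, not vela code).
def lrelu_lut_entry_float(x, alpha, ifm_scale, zp_in, ofm_scale, zp_out, qmin, qmax):
    real = (x - zp_in) * ifm_scale               # dequantise the IFM value
    real = real if real >= 0 else alpha * real   # LeakyRelu in the real domain
    q = zp_out + round(real / ofm_scale)         # requantise into the OFM domain
    return min(qmax, max(qmin, q))               # clamp to the quantised range


# int8 example with alpha=0.1 and identical input/output quantisation:
table = [lrelu_lut_entry_float(x, 0.1, 0.05, 0, 0.05, 0, -128, 127) for x in range(-128, 128)]
assert table[0] == -13    # x=-128: -128 * 0.05 * 0.1 / 0.05 = -12.8 -> -13
assert table[255] == 127  # x=127: positive values pass through unchanged
```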
@@ -763,13 +799,12 @@ def convert_lrelu(op, arch):
if op.type != "LeakyRelu":
return op
ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
- if ifm.is_scaling_equal(ofm) and ifm.dtype == ofm.dtype:
- if ifm.dtype in (DataType.uint8, DataType.int8):
- # use LUT
- return convert_lrelu_to_lut(op, arch)
- elif ifm.dtype == DataType.int16:
- # use LeakyRelu unmodified
- return op
+ if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
+ # use LUT for int8/uint8
+ return convert_lrelu_to_lut(op, arch)
+ if ifm.is_scaling_equal(ofm) and ifm.dtype == ofm.dtype and ifm.dtype == DataType.int16:
+ # use LeakyRelu unmodified for int16 with equal input/output scaling
+ return op
return convert_lrelu_to_mul_max(op, arch)
@@ -802,7 +837,7 @@ def fuse_activation_function_with_prev(op, arch):
if not fuse:
return op
# Move the fused activation function + corresponding info to prev_op
- for attr in ("fused_activation_function", "alpha"):
+ for attr in ("fused_activation_function", "alpha", "forced_output_quantization"):
if attr in op.attrs:
prev_op.attrs[attr] = op.attrs[attr]
if op.activation_lut is not None:
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 8d9f9185..609fcc6b 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -442,6 +442,13 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
fmf = primary_op.attrs.get("fused_memory_function", None)
faf = primary_op.attrs.get("fused_activation_function", None)
fused_quantize = any(op.type == "Quantize" for op in ps.ops)
+ # Force output scale, used in operations with fused LUT
+ # Note: with current LUT support, forced_ofm_quantization is always equal to cmd.ofm_tensor.quantization
+ # except when primary_op is AddAct + 0 (no-op) + LUT
+ forced_ofm_quantization = primary_op.attrs.get("forced_output_quantization", None)
+ ofm_quant = cmd.ofm_tensor.quantization
+ if forced_ofm_quantization is not None:
+ ofm_quant = forced_ofm_quantization
# Specifies which operand to apply scaling to in bitexact elementwise ADD/SUB
op_to_scale = 0
@@ -476,7 +483,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
if primary_op.type in set(("AddAct", "MulAct", "SubAct",)):
input_scale = cmd.ifm_tensor.quantization.scale_f32
input2_scale = cmd.ifm2_tensor.quantization.scale_f32
- output_scale = cmd.ofm_tensor.quantization.scale_f32
+ output_scale = ofm_quant.scale_f32
use_global_scale = True
if output_scale is not None and faf in ("Sigmoid", "Tanh"):
@@ -491,7 +498,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
else: # AddAct/SubAct
# Force output scale same as the input scale for
- # resizebiliner 1x1 that is converted to add
+ # resizebilinear 1x1 that is converted to add
if "resizebilinear" in primary_op.attrs:
output_scale = input2_scale
@@ -529,7 +536,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
elif primary_op.type in set(("LeakyRelu", "Abs",)):
- output_scale = cmd.ofm_tensor.quantization.scale_f32
+ output_scale = ofm_quant.scale_f32
use_global_scale = True
if primary_op.type == "LeakyRelu":
@@ -664,7 +671,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
elif fused_quantize:
# Quantize op requires different scaling
ifm_scale_f64 = np.double(cmd.ifm_tensor.quantization.scale_f32)
- ofm_scale_f64 = np.double(cmd.ofm_tensor.quantization.scale_f32)
+ ofm_scale_f64 = np.double(ofm_quant.scale_f32)
scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64)
elif primary_op.type == "ResizeBilinear" and "rescale" in primary_op.attrs:
rescale = primary_op.attrs["rescale"]
@@ -676,11 +683,8 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
# k_height == k_width == 1 is always true in this case
# Normally the scale is maximised, to get maximum precision, which means that
# if rescale != 1, scale need to consider the number of bits needed for rescaling
- if None not in (
- cmd.ofm_tensor.quantization.scale_f32,
- cmd.ifm_tensor.quantization.scale_f32,
- ):
- rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
+ if None not in (ofm_quant.scale_f32, cmd.ifm_tensor.quantization.scale_f32,):
+ rescale = cmd.ifm_tensor.quantization.scale_f32 / ofm_quant.scale_f32
rescale_bits = 0
if k_height == k_width == 1:
if fmf == "ConcatSliceWrite":
@@ -797,9 +801,8 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
scale_region = base_ptr_idx_map[cmd.scale_tensor.mem_type]
emit.cmd0_with_param(cmd0.NPU_SET_SCALE_REGION, scale_region)
- ofm_quant = cmd.ofm_tensor.quantization
- ofm_quant_qmin = cmd.ofm_tensor.quantization.quant_min
- ofm_quant_qmax = cmd.ofm_tensor.quantization.quant_max
+ ofm_quant_qmin = ofm_quant.quant_min
+ ofm_quant_qmax = ofm_quant.quant_max
ifm_min = cmd.ifm_tensor.quantization.min
ifm_max = cmd.ifm_tensor.quantization.max
@@ -912,13 +915,15 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
emit.cmd0_with_param(zero_point_op, 0)
else:
assert tens.quantization.zero_point is not None, "need an actual zero point set"
- if (
+ if cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op and forced_ofm_quantization is not None:
+ zero_point = forced_ofm_quantization.zero_point
+ elif (
"resizebilinear" in primary_op.attrs
and primary_op.type == "AddAct"
and cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op
):
# Force output zero point same as the input zero point
- # for resizebiliner 1x1 that is converted to add
+ # for resizebilinear 1x1 that is converted to add
zero_point = cmd.ifm2_tensor.quantization.zero_point
else:
zero_point = tens.quantization.zero_point
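
The register-generator changes above all follow one precedence rule, summarised by this hedged sketch (not the generator's actual structure): when the primary op carries "forced_output_quantization", both the OFM scale and the OFM zero point are programmed from that quantization instead of the OFM tensor's own.

```python
# Hedged summary of the override rule (illustrative helper, not vela code).
def effective_ofm_quantization(primary_op_attrs, ofm_tensor_quantization):
    forced = primary_op_attrs.get("forced_output_quantization", None)
    # With current LUT support this only differs from the OFM tensor's own
    # quantization for the AddAct + 0 (no-op) + LUT case noted above.
    return forced if forced is not None else ofm_tensor_quantization
```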
diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py
index 2dde1e4b..8c1ed679 100644
--- a/ethosu/vela/test/test_fp_math.py
+++ b/ethosu/vela/test/test_fp_math.py
@@ -19,6 +19,7 @@ import numpy as np
import pytest
from ethosu.vela import fp_math
+from ethosu.vela import scaling
from ethosu.vela.softmax import SoftMax
# Turn off black formatting for EXP_LUT to keep it compact
@@ -116,3 +117,39 @@ def test_exp():
sm = SoftMax(None)
for (expected, actual) in zip(EXP_LUT, sm.generate_exp_table(1.0, np.float32(0.05123165))):
assert actual == expected
+
+
+multiply_test_data = [
+ (0, 0, 0),
+ (0, 0.7, 0),
+ (0, 55.8, 0),
+ (6, 0.3, 2),
+ (200, 0, 0),
+ (1, 1, 1),
+ (1, 0.1, 0),
+ (1, 3.49, 3),
+ (1, 3.51, 4),
+ (27, 1, 27),
+ (13, 0.9, 12),
+ (3, 21.2, 64),
+ (1000, 2000, 2000000),
+ (32767, 32767, 32767 * 32767), # extreme values
+]
+
+
+@pytest.mark.parametrize("x, factor, expected", multiply_test_data)
+def test_multiply_by_quantized_multiplier(x, factor, expected):
+ scale, shift = scaling.quantise_scale(factor)
+ assert fp_math.multiply_by_quantized_multiplier(x, scale, shift) == expected
+ assert fp_math.multiply_by_quantized_multiplier(-x, scale, shift) == -expected
+ assert fp_math.multiply_by_quantized_multiplier(x, -scale, shift) == -expected
+ assert fp_math.multiply_by_quantized_multiplier(-x, -scale, shift) == expected
+
+
+def test_multiply_by_quantized_multiplier_int16_limits():
+ # Tests min/max limits of foreseen practical usage of multiply_by_quantized_multiplier
+ # for the purpose of calculating LUTs
+ for x in [-32768, 32767]:
+ for y in [-32768, 32767]:
+ scale, shift = scaling.quantise_scale(y)
+ assert fp_math.multiply_by_quantized_multiplier(x, scale, shift) == x * y
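
A worked check of one multiply_test_data row, (1, 3.51, 4), under the Q31 convention assumed earlier for scaling.quantise_scale (the significand value shown is approximate):

```python
import math

significand, exponent = math.frexp(3.51)      # 3.51 == 0.8775 * 2**2
scale = int(round(significand * (1 << 31)))   # ~= 1884416901
shift = 31 - exponent                         # == 29
# multiply_by_quantized_multiplier(1, scale, 29) left-shifts x by 31 - 29 = 2 bits
# and takes the rounded high 32 bits of 4 * scale, i.e. round(~3.51) == 4.
assert round((1 << (31 - shift)) * scale / (1 << 31)) == 4
```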