aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFredrik Svedberg <fredrik.svedberg@arm.com>2022-09-16 09:39:26 +0200
committerFredrik Svedberg <fredrik.svedberg@arm.com>2022-09-16 11:01:39 +0200
commit36424312fcc7c279a929073160ca7191a926e77b (patch)
tree138eaa1ad66c22110d2e1ff47e204ddbf30ad357
parent17e53b5d776109e1bd1073c657ff0453ccf3c09e (diff)
downloadethos-u-vela-36424312fcc7c279a929073160ca7191a926e77b.tar.gz
MLBEDSW-6938 Fix PReLU optimisation
Fixed PReLU optimisation to LeakyReLU with negative alpha. Added optimisation of LeakyReLU to ReLU when alpha is zero. Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com> Change-Id: I5e66f79b29908fffd95b6115799021138ebb401a
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py30
1 files changed, 18 insertions, 12 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index fb98e21b..0630ef41 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1019,11 +1019,12 @@ def convert_lrelu_to_mul_max(op, arch):
alpha = np.float32(op.attrs["alpha"])
use_mul_max = 0 < alpha < 1
+ is_converted_prelu = "alpha_scaling" in op.attrs
if use_mul_max:
mul_ifm = ifm
new_op = Op.Maximum
else:
- # Need to use a different approach for alpha > 1
+ # Need to use a different approach for alpha < 0 or alpha > 1
no_scale_quant = ifm.quantization.clone()
no_scale_quant.scale_f32 = None
no_scale_quant.zero_point = 0
@@ -1034,7 +1035,10 @@ def convert_lrelu_to_mul_max(op, arch):
min_op.add_input_tensor(ifm)
min_op.add_input_tensor(zero)
mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
- mul_ifm.dtype = DataType.int32
+ if alpha < 0 and not is_converted_prelu:
+ # For negative alpha that is not from a converted PReLU we need to use
+ # int32 Mul below to perform the (negative) alpha scaling
+ mul_ifm.dtype = DataType.int32
min_op.set_output_tensor(mul_ifm)
min_op.set_ifm_ofm_shapes()
new_op = Op.RescaleAdd
@@ -1050,8 +1054,8 @@ def convert_lrelu_to_mul_max(op, arch):
quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
quantization.zero_point = 0
alpha_dtype = mul_ifm.dtype
- if "alpha_scaling" in op.attrs:
- # The LeakyRelu was the result from convert_prelu
+ if is_converted_prelu:
+ # The LeakyRelu was the result from convert_prelu and the scaling is provided
scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
mul_alpha.type = Op.RescaleMul
mul_alpha.rescale = [alpha_scale, alpha_shift]
@@ -1062,7 +1066,7 @@ def convert_lrelu_to_mul_max(op, arch):
else:
quantization.scale_f32 = alpha
if alpha_dtype == DataType.int32:
- # When the datatype is int32 we need to do the scaling with the multiplication
+ # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
else:
scalar = 1
@@ -1179,14 +1183,16 @@ def convert_lrelu(op, arch, nng):
ifm, ofm = op.get_ifm_ofm()
if ifm is None or ofm is None:
return op
+ alpha = op.attrs["alpha"]
+ if alpha == 0:
+ # When alpha is 0 the operation can be converted to a ReLU
+ op.type = Op.Relu
+ op.name = op.name.replace("LeakyRelu", op.type.name)
+ return op
if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
# use LUT for int8/uint8
return convert_lrelu_to_lut(op, arch)
- if (
- check_quantized_tens_scaling_equal(ifm, ofm)
- and ifm.dtype == ofm.dtype == DataType.int16
- and op.attrs["alpha"] >= 0
- ):
+ if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
# use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
return op
return convert_lrelu_to_mul_max(op, arch)
@@ -1830,6 +1836,8 @@ def tflite_optimise_graph(nng, arch):
convert_conv_to_fc,
convert_softmax,
convert_prelu,
+ convert_mul_max_to_abs_or_lrelu,
+ convert_lrelu,
optimise_strided_conv,
convert_hardswish_to_lut,
rewrite_fully_connected_input,
@@ -1840,8 +1848,6 @@ def tflite_optimise_graph(nng, arch):
fixup_resize,
fixup_bias_tensors,
fixup_asymmetric_weights,
- convert_mul_max_to_abs_or_lrelu,
- convert_lrelu,
convert_tanh_sigmoid_to_lut,
replace_pad_by_hw_pad,
]