diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index ed8fa1e3..3646b01e 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -746,6 +746,58 @@ def convert_softmax(op, arch, nng): return op +def convert_prelu(op, arch, nng): + if op.type == Op.Prelu: + ifm, alpha, ofm = op.get_ifm_ifm2_ofm() + if None in (ifm, alpha, ofm): + return op + + no_scale_quant = ifm.quantization.clone() + no_scale_quant.scale_f32 = None + no_scale_quant.zero_point = 0 + zero = create_const_tensor("zero_const", [1, 1, 1, 1], ifm.dtype, [0], quantization=no_scale_quant) + + # Select values < 0 + min_op = Operation(Op.Minimum, op.name + "_min") + min_op.add_input_tensor(ifm) + min_op.add_input_tensor(zero) + fm_negative = ifm.clone(op.name + "_negative", set_unique=True) + min_op.set_output_tensor(fm_negative) + min_op.set_ifm_ofm_shapes() + DebugDatabase.add_optimised(op, min_op) + + # and multiply with alpha tensor + mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha") + mul_alpha.add_input_tensor(fm_negative) + mul_alpha.add_input_tensor(alpha) + fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True) + mul_alpha.set_output_tensor(fm_alpha) + mul_alpha.set_ifm_ofm_shapes() + DebugDatabase.add_optimised(op, mul_alpha) + + # Select (and scale) values > 0 + relu_op = Operation(Op.Relu, op.name + "_relu") + relu_op.add_input_tensor(ifm) + fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True) + relu_op.set_output_tensor(fm_scaled) + relu_op.set_ifm_ofm_shapes() + DebugDatabase.add_optimised(op, relu_op) + + # Add scaled and alpha multiplied values (without scaling) + add_op = Operation(Op.RescaleAdd, op.name + "_add") + add_op.rescale = (1, 0) # No scale or shift + add_op.add_input_tensor(fm_alpha) + add_op.add_input_tensor(fm_scaled) + add_op.set_output_tensor(ofm) + add_op.set_ifm_ofm_shapes() + + DebugDatabase.add_optimised(op, add_op) + ifm.consumer_list.remove(op) + op = add_op + + return op + + def convert_mul_max_to_abs_or_lrelu(op, arch, nng): r"""Whenever there is a subgraph with this topology: @@ -1648,6 +1700,7 @@ def tflite_optimise_graph(nng, arch): convert_depthwise_to_conv, convert_conv_to_fc, convert_softmax, + convert_prelu, optimise_strided_conv, convert_hardswish_to_lut, rewrite_fully_connected_input, |