From 8ddd4899892dace88306b3b155dbf47cc47fa4cd Mon Sep 17 00:00:00 2001
From: Fredrik Svedberg
Date: Fri, 19 Aug 2022 16:06:04 +0200
Subject: MLBEDSW-6832 PReLU support in Vela

Added PReLU support in graph optimiser.

Signed-off-by: Fredrik Svedberg
Change-Id: I3a188675e3edcdf0b4a4bfcdd134fda0bf8a560f
---
 ethosu/vela/operation.py                  |  2 +-
 ethosu/vela/tflite_graph_optimiser.py     | 53 +++++++++++++++++++++++++++++++
 ethosu/vela/tflite_mapping.py             |  2 +-
 ethosu/vela/tflite_supported_operators.py | 10 +++++-
 4 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 47f4fe0f..54e823a8 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -229,7 +229,7 @@ class Op(Enum):
     PadV2 = OperatorInfo()
     Placeholder = OperatorInfo()  # Only used in CPU subgraphs
     Pow = OperatorInfo()
-    Prelu = OperatorInfo()
+    Prelu = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
     Prod = OperatorInfo()
     Quantize = OperatorInfo(indices=NNG_IFM_INDICES)
     QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ed8fa1e3..3646b01e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -746,6 +746,58 @@ def convert_softmax(op, arch, nng):
     return op
 
 
+def convert_prelu(op, arch, nng):
+    if op.type == Op.Prelu:
+        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
+        if None in (ifm, alpha, ofm):
+            return op
+
+        no_scale_quant = ifm.quantization.clone()
+        no_scale_quant.scale_f32 = None
+        no_scale_quant.zero_point = 0
+        zero = create_const_tensor("zero_const", [1, 1, 1, 1], ifm.dtype, [0], quantization=no_scale_quant)
+
+        # Select values < 0
+        min_op = Operation(Op.Minimum, op.name + "_min")
+        min_op.add_input_tensor(ifm)
+        min_op.add_input_tensor(zero)
+        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
+        min_op.set_output_tensor(fm_negative)
+        min_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, min_op)
+
+        # and multiply with alpha tensor
+        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
+        mul_alpha.add_input_tensor(fm_negative)
+        mul_alpha.add_input_tensor(alpha)
+        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
+        mul_alpha.set_output_tensor(fm_alpha)
+        mul_alpha.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, mul_alpha)
+
+        # Select (and scale) values > 0
+        relu_op = Operation(Op.Relu, op.name + "_relu")
+        relu_op.add_input_tensor(ifm)
+        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
+        relu_op.set_output_tensor(fm_scaled)
+        relu_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, relu_op)
+
+        # Add scaled and alpha multiplied values (without scaling)
+        add_op = Operation(Op.RescaleAdd, op.name + "_add")
+        add_op.rescale = (1, 0)  # No scale or shift
+        add_op.add_input_tensor(fm_alpha)
+        add_op.add_input_tensor(fm_scaled)
+        add_op.set_output_tensor(ofm)
+        add_op.set_ifm_ofm_shapes()
+
+        DebugDatabase.add_optimised(op, add_op)
+        ifm.consumer_list.remove(op)
+        op = add_op
+
+    return op
+
+
 def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
     r"""Whenever there is a subgraph with this topology:
 
@@ -1648,6 +1700,7 @@ def tflite_optimise_graph(nng, arch):
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
+        convert_prelu,
         optimise_strided_conv,
         convert_hardswish_to_lut,
         rewrite_fully_connected_input,
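Note: the rewrite in convert_prelu() relies on the identity PReLU(x) = Relu(x) + alpha * min(x, 0): for x >= 0 the Minimum/Mul branch contributes zero, and for x < 0 the Relu branch contributes zero, so the final RescaleAdd recombines the two branches exactly. A minimal NumPy sketch of the same decomposition, for illustration only (not part of the patch):

    import numpy as np

    def prelu_reference(x, alpha):
        # PReLU(x) = x for x >= 0, alpha * x for x < 0
        return np.where(x >= 0, x, alpha * x)

    def prelu_decomposed(x, alpha):
        fm_negative = np.minimum(x, 0.0)  # Op.Minimum against a zero constant
        fm_alpha = fm_negative * alpha    # Op.Mul with the alpha tensor
        fm_scaled = np.maximum(x, 0.0)    # Op.Relu
        return fm_alpha + fm_scaled       # Op.RescaleAdd with rescale=(1, 0)

    x = np.linspace(-4.0, 4.0, 9)
    assert np.allclose(prelu_reference(x, 0.25), prelu_decomposed(x, 0.25))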
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index c515d23a..3ccedc73 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -732,7 +732,7 @@ builtin_operator_map = {
         ),
         TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.PRELU: (Op.Prelu, None, TFLITE_NO_INDICES),
+    BuiltinOperator.PRELU: (Op.Prelu, None, TFLITE_IFM_IFM2_INDICES),
     BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
     BuiltinOperator.ARG_MAX: (
         Op.ArgMax,
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 5d25e37b..1915d43b 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -123,7 +123,15 @@ class TFLiteSupportedOperators:
             Op.Clip,
         )
     )
-    activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
+    activation_ops = relu_ops | set(
+        (
+            Op.Tanh,
+            Op.Sigmoid,
+            Op.Softmax,
+            Op.HardSwish,
+            Op.Prelu,
+        )
+    )
     npu_post_ops = (
         # activation functions
         activation_ops
-- 
cgit v1.2.1
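Exercising this end to end requires a quantised TFLite model containing the PRELU builtin, since Vela compiles quantised (e.g. int8) networks. One possible way to generate such a test model with TensorFlow; this is a sketch under the assumption that the converter fuses the Keras PReLU layer to the PRELU builtin, and the shapes and file name are illustrative:

    import numpy as np
    import tensorflow as tf

    # Tiny model with a PReLU activation (alpha shared over H and W).
    inp = tf.keras.Input(shape=(8, 8, 4), batch_size=1)
    out = tf.keras.layers.PReLU(shared_axes=[1, 2])(inp)
    model = tf.keras.Model(inp, out)

    def representative_data():
        for _ in range(10):
            yield [np.random.rand(1, 8, 8, 4).astype(np.float32)]

    # Full-integer post-training quantisation.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    with open("prelu.tflite", "wb") as f:
        f.write(converter.convert())

    # The model can then be compiled for the NPU with: vela prelu.tflite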