aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ed8fa1e3..3646b01e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -746,6 +746,58 @@ def convert_softmax(op, arch, nng):
return op
+def convert_prelu(op, arch, nng):
+ if op.type == Op.Prelu:
+ ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
+ if None in (ifm, alpha, ofm):
+ return op
+
+ no_scale_quant = ifm.quantization.clone()
+ no_scale_quant.scale_f32 = None
+ no_scale_quant.zero_point = 0
+ zero = create_const_tensor("zero_const", [1, 1, 1, 1], ifm.dtype, [0], quantization=no_scale_quant)
+
+ # Select values < 0
+ min_op = Operation(Op.Minimum, op.name + "_min")
+ min_op.add_input_tensor(ifm)
+ min_op.add_input_tensor(zero)
+ fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
+ min_op.set_output_tensor(fm_negative)
+ min_op.set_ifm_ofm_shapes()
+ DebugDatabase.add_optimised(op, min_op)
+
+ # and multiply with alpha tensor
+ mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
+ mul_alpha.add_input_tensor(fm_negative)
+ mul_alpha.add_input_tensor(alpha)
+ fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
+ mul_alpha.set_output_tensor(fm_alpha)
+ mul_alpha.set_ifm_ofm_shapes()
+ DebugDatabase.add_optimised(op, mul_alpha)
+
+ # Select (and scale) values > 0
+ relu_op = Operation(Op.Relu, op.name + "_relu")
+ relu_op.add_input_tensor(ifm)
+ fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
+ relu_op.set_output_tensor(fm_scaled)
+ relu_op.set_ifm_ofm_shapes()
+ DebugDatabase.add_optimised(op, relu_op)
+
+ # Add scaled and alpha multiplied values (without scaling)
+ add_op = Operation(Op.RescaleAdd, op.name + "_add")
+ add_op.rescale = (1, 0) # No scale or shift
+ add_op.add_input_tensor(fm_alpha)
+ add_op.add_input_tensor(fm_scaled)
+ add_op.set_output_tensor(ofm)
+ add_op.set_ifm_ofm_shapes()
+
+ DebugDatabase.add_optimised(op, add_op)
+ ifm.consumer_list.remove(op)
+ op = add_op
+
+ return op
+
+
def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
r"""Whenever there is a subgraph with this topology:
@@ -1648,6 +1700,7 @@ def tflite_optimise_graph(nng, arch):
convert_depthwise_to_conv,
convert_conv_to_fc,
convert_softmax,
+ convert_prelu,
optimise_strided_conv,
convert_hardswish_to_lut,
rewrite_fully_connected_input,