diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index c79f154a..1b70165e 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -34,7 +34,6 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .graph_optimiser_util import bypass_memory_only_ops from .graph_optimiser_util import calc_explicit_padding from .graph_optimiser_util import convert_depthwise_to_conv -from .graph_optimiser_util import convert_to_lut from .graph_optimiser_util import create_avg_pool_for_concat from .graph_optimiser_util import memory_only_ops from .graph_optimiser_util import move_splitsliceread_to_consumer @@ -42,6 +41,9 @@ from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence from .lstm import Lstm +from .lut import convert_to_lut +from .lut import create_lut_8bit_op +from .lut import create_lut_int16_op from .numeric_util import clamp_sigmoid from .numeric_util import full_shape from .numeric_util import round_away_zero @@ -1935,6 +1937,19 @@ def convert_mean_to_depthwise_conv(op, arch, nng): return op +def convert_ops_to_lut(op, arch, nng): + if op.type == Op.Exp: + if op.ifm.dtype == DataType.int8: + return create_lut_8bit_op(op, math.exp, "exp") + elif op.ifm.dtype == DataType.int16: + return create_lut_int16_op(op, math.exp, "exp") + else: + # Should already be catched in tflite supported ops + assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}" + + return op + + def optimise_quantize(op: Operation, arch, nng): if op.type == Op.Quantize and op.run_on_npu: @@ -2214,6 +2229,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): # Rewrite of operators op_rewrite_list = [ set_tensor_equivalence, + convert_ops_to_lut, convert_mean_to_depthwise_conv, convert_depthwise_to_conv, convert_conv_to_fc, |