diff options
author | Johan Alfven <johan.alfven@arm.com> | 2023-04-24 13:35:40 +0200 |
---|---|---|
committer | Rickard Bolin <rickard.bolin@arm.com> | 2023-05-02 11:03:37 +0000 |
commit | ce5027328d2330d33bdfc5d5b016d171e4f8a2fc (patch) | |
tree | 16528c451c430d14d3221bfe968e94a54c6af33a /ethosu/vela/tflite_graph_optimiser.py | |
parent | 463f74b2ed65753811ddf7460e760aaf983ef5bb (diff) | |
download | ethos-u-vela-ce5027328d2330d33bdfc5d5b016d171e4f8a2fc.tar.gz |
MLBEDSW-2082: Add Exp support
- Added int8 and int16 Exp support, implemented as LUT.
- Added generic 8bit and 16bit LUT table functions following
the implementation in the latest reference. If new ops are added
by the reference, they can easily be implemented in Vela using
the generic functions.
- Moved convert_to_lut to lut.py to have all LUT related code in
one file.
- Updated SUPPORTED_OPS.md
Change-Id: I388e76ea4b39162313599a5341cfb9bad71a782c
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index c79f154a..1b70165e 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -34,7 +34,6 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .graph_optimiser_util import bypass_memory_only_ops from .graph_optimiser_util import calc_explicit_padding from .graph_optimiser_util import convert_depthwise_to_conv -from .graph_optimiser_util import convert_to_lut from .graph_optimiser_util import create_avg_pool_for_concat from .graph_optimiser_util import memory_only_ops from .graph_optimiser_util import move_splitsliceread_to_consumer @@ -42,6 +41,9 @@ from .graph_optimiser_util import needed_total_padding from .graph_optimiser_util import set_ifm_ofm_op_shapes from .graph_optimiser_util import set_tensor_equivalence from .lstm import Lstm +from .lut import convert_to_lut +from .lut import create_lut_8bit_op +from .lut import create_lut_int16_op from .numeric_util import clamp_sigmoid from .numeric_util import full_shape from .numeric_util import round_away_zero @@ -1935,6 +1937,19 @@ def convert_mean_to_depthwise_conv(op, arch, nng): return op +def convert_ops_to_lut(op, arch, nng): + if op.type == Op.Exp: + if op.ifm.dtype == DataType.int8: + return create_lut_8bit_op(op, math.exp, "exp") + elif op.ifm.dtype == DataType.int16: + return create_lut_int16_op(op, math.exp, "exp") + else: + # Should already be catched in tflite supported ops + assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}" + + return op + + def optimise_quantize(op: Operation, arch, nng): if op.type == Op.Quantize and op.run_on_npu: @@ -2214,6 +2229,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights): # Rewrite of operators op_rewrite_list = [ set_tensor_equivalence, + convert_ops_to_lut, convert_mean_to_depthwise_conv, convert_depthwise_to_conv, convert_conv_to_fc, |