author     Johan Alfven <johan.alfven@arm.com>    2023-04-24 13:35:40 +0200
committer  Rickard Bolin <rickard.bolin@arm.com>  2023-05-02 11:03:37 +0000
commit     ce5027328d2330d33bdfc5d5b016d171e4f8a2fc (patch)
tree       16528c451c430d14d3221bfe968e94a54c6af33a /ethosu/vela/tflite_graph_optimiser.py
parent     463f74b2ed65753811ddf7460e760aaf983ef5bb (diff)
MLBEDSW-2082: Add Exp support
- Added int8 and int16 Exp support, implemented as a LUT.
- Added generic 8-bit and 16-bit LUT table functions following the
  implementation in the latest reference. If new ops are added by the
  reference, they can easily be implemented in Vela using the generic
  functions.
- Moved convert_to_lut to lut.py to keep all LUT-related code in one file.
- Updated SUPPORTED_OPS.md

Change-Id: I388e76ea4b39162313599a5341cfb9bad71a782c
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
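For context on the LUT approach: an int8 tensor can take only 256 distinct values, so an elementwise op such as Exp can be precomputed in full. A minimal standalone sketch of the idea, assuming per-tensor quantization parameters (illustrative only, not Vela's actual lut.py implementation):

import math

import numpy as np


def make_int8_exp_lut(ifm_scale, ifm_zero_point, ofm_scale, ofm_zero_point):
    """Precompute exp() for every possible int8 input value.

    Each entry dequantizes a candidate input, applies the real-valued
    function, and requantizes the result into the output domain.
    """
    lut = np.zeros(256, dtype=np.int8)
    for i, q_in in enumerate(range(-128, 128)):
        real = ifm_scale * (q_in - ifm_zero_point)           # dequantize
        q_out = round(math.exp(real) / ofm_scale) + ofm_zero_point
        lut[i] = int(np.clip(q_out, -128, 127))              # requantize + clamp
    return lut

The generic 8-bit helper mentioned above presumably factors out this dequantize/apply/requantize loop, so supporting a new elementwise op reduces to passing a different Python function — which is how the diff below calls create_lut_8bit_op(op, math.exp, "exp").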
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py  18
1 file changed, 17 insertions(+), 1 deletion(-)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index c79f154a..1b70165e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -34,7 +34,6 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
-from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import create_avg_pool_for_concat
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
@@ -42,6 +41,9 @@ from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .lstm import Lstm
+from .lut import convert_to_lut
+from .lut import create_lut_8bit_op
+from .lut import create_lut_int16_op
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
@@ -1935,6 +1937,19 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
    return op


+def convert_ops_to_lut(op, arch, nng):
+    if op.type == Op.Exp:
+        if op.ifm.dtype == DataType.int8:
+            return create_lut_8bit_op(op, math.exp, "exp")
+        elif op.ifm.dtype == DataType.int16:
+            return create_lut_int16_op(op, math.exp, "exp")
+        else:
+            # Should already be caught in tflite supported ops
+            assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"
+
+    return op
+
+
def optimise_quantize(op: Operation, arch, nng):

    if op.type == Op.Quantize and op.run_on_npu:
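For the int16 path a full table (65536 entries) would be impractical, so the TFLite reference implementation that the commit message cites samples the function at a small number of points and interpolates between them at lookup time. A hedged sketch of that sampling step (the 513-entry size follows the reference kernels' int16 LUT convention; the helper name is illustrative, not Vela's):

import math

import numpy as np


def make_int16_exp_table(ifm_scale, ofm_scale, num_entries=513):
    """Sample exp() at evenly spaced points across the int16 input range.

    TFLite int16 quantization is symmetric (zero point 0), so only the
    scales are needed. At lookup time, the two samples bracketing an
    input are interpolated linearly.
    """
    table = np.zeros(num_entries, dtype=np.int16)
    step = 65536.0 / (num_entries - 1)
    for i in range(num_entries):
        real = ifm_scale * (-32768 + i * step)        # dequantize (zp == 0)
        q_out = round(math.exp(real) / ofm_scale)     # apply + requantize
        table[i] = int(np.clip(q_out, -32768, 32767))
    return table

create_lut_int16_op presumably wraps this kind of sampling plus whatever encoding the Ethos-U LUT hardware expects; the sketch shows only the numeric idea.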
@@ -2214,6 +2229,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
+        convert_ops_to_lut,
        convert_mean_to_depthwise_conv,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
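The hunk above registers convert_ops_to_lut in the rewrite list. The contract visible in the diff is that each entry is a callable taking (op, arch, nng) and returning either the original op or a replacement. A minimal illustration of how such a list can be driven (a sketch of the pattern, not Vela's actual rewrite_graph machinery):

def apply_op_rewrites(ops, rewrite_list, arch=None, nng=None):
    """Run every rewrite over every op, threading each op through
    the remaining rewrites so replacements are rewritten too."""
    rewritten = []
    for op in ops:
        for rewrite in rewrite_list:
            op = rewrite(op, arch, nng)
        rewritten.append(op)
    return rewritten

Because convert_ops_to_lut sits near the front of op_rewrite_list, an Exp op is already replaced by its LUT-based equivalent before the later convolution-oriented rewrites run.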