aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index c79f154a..1b70165e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -34,7 +34,6 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
-from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import create_avg_pool_for_concat
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
@@ -42,6 +41,9 @@ from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .lstm import Lstm
+from .lut import convert_to_lut
+from .lut import create_lut_8bit_op
+from .lut import create_lut_int16_op
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
@@ -1935,6 +1937,19 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
return op
+def convert_ops_to_lut(op, arch, nng):
+ if op.type == Op.Exp:
+ if op.ifm.dtype == DataType.int8:
+ return create_lut_8bit_op(op, math.exp, "exp")
+ elif op.ifm.dtype == DataType.int16:
+ return create_lut_int16_op(op, math.exp, "exp")
+ else:
+ # Should already be catched in tflite supported ops
+ assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"
+
+ return op
+
+
def optimise_quantize(op: Operation, arch, nng):
if op.type == Op.Quantize and op.run_on_npu:
@@ -2214,6 +2229,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
# Rewrite of operators
op_rewrite_list = [
set_tensor_equivalence,
+ convert_ops_to_lut,
convert_mean_to_depthwise_conv,
convert_depthwise_to_conv,
convert_conv_to_fc,