diff options
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 5 | ||||
-rw-r--r-- | ethosu/vela/numeric_util.py | 6 |
2 files changed, 4 insertions, 7 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 81d5a188..e7c15cdc 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -27,9 +27,9 @@ from . import scaling from .data_type import DataType from .errors import UnsupportedFeatureError from .ethos_u55_regs.ethos_u55_regs import resampling_mode +from .numeric_util import clamp_sigmoid from .numeric_util import full_shape from .numeric_util import round_away_zero -from .numeric_util import sigmoid from .operation import create_avgpool_nop from .operation import NpuBlockType from .operation import Operation @@ -447,6 +447,7 @@ def unfuse_activation_function(op, arch): return op + def fixup_unpack_output(tens, arch): op = tens.ops[0] if op.type in set(("Unpack", "StridedSlice")): @@ -974,7 +975,7 @@ def convert_lrelu(op, arch): def convert_tanh_sigmoid_to_lut(op, arch): # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution if op.type == "Sigmoid": - return convert_to_lut8(op, sigmoid) + return convert_to_lut8(op, clamp_sigmoid) elif op.type == "Tanh": return convert_to_lut8(op, math.tanh) return op diff --git a/ethosu/vela/numeric_util.py b/ethosu/vela/numeric_util.py index 3d26444a..4ebef8e5 100644 --- a/ethosu/vela/numeric_util.py +++ b/ethosu/vela/numeric_util.py @@ -77,17 +77,13 @@ def clamp_tanh(x): return y -def sigmoid(x): - return 1 / (1 + math.exp(-x)) - - def clamp_sigmoid(x): if x <= -8: y = 0.0 elif x >= 8: y = 1.0 else: - y = sigmoid(x) + y = 1 / (1 + math.exp(-x)) return y |