aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index b3474147..df6b5759 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -19,7 +19,6 @@
import numpy as np
from . import rewrite_graph
-from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_memory_only_ops
@@ -32,6 +31,7 @@ from .graph_optimiser_util import set_tensor_equivalence
from .lut import convert_to_lut
from .operation import ExplicitScaling
from .operation import Op
+from .operation import RoundingMode
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_pad_nop
@@ -115,7 +115,7 @@ def calc_scaling_avgpool(op, arch, nng):
multiplier.append(numerator // kernel_wh)
shift.append(30 + k)
- op.rounding_mode = NpuRoundingMode.NATURAL
+ op.rounding_mode = RoundingMode.HalfUp
op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
return op
@@ -399,9 +399,9 @@ def rewrite_rescale(op, arch, nng):
explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
if double_round and scale32:
- rounding_mode = NpuRoundingMode.TFL
+ rounding_mode = RoundingMode.TFLite
else:
- rounding_mode = NpuRoundingMode.NATURAL
+ rounding_mode = RoundingMode.HalfUp
if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
assert len(multiplier) == len(shift) == len(prev_op.bias.values)