diff options
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index b3474147..df6b5759 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -19,7 +19,6 @@ import numpy as np from . import rewrite_graph -from .api import NpuRoundingMode from .data_type import DataType from .debug_database import DebugDatabase from .graph_optimiser_util import bypass_memory_only_ops @@ -32,6 +31,7 @@ from .graph_optimiser_util import set_tensor_equivalence from .lut import convert_to_lut from .operation import ExplicitScaling from .operation import Op +from .operation import RoundingMode from .operation_util import create_add_nop from .operation_util import create_avgpool_nop from .operation_util import create_pad_nop @@ -115,7 +115,7 @@ def calc_scaling_avgpool(op, arch, nng): multiplier.append(numerator // kernel_wh) shift.append(30 + k) - op.rounding_mode = NpuRoundingMode.NATURAL + op.rounding_mode = RoundingMode.HalfUp op.explicit_scaling = ExplicitScaling(False, shift, multiplier) return op @@ -399,9 +399,9 @@ def rewrite_rescale(op, arch, nng): explicit_scaling = ExplicitScaling(per_channel, shift, multiplier) if double_round and scale32: - rounding_mode = NpuRoundingMode.TFL + rounding_mode = RoundingMode.TFLite else: - rounding_mode = NpuRoundingMode.NATURAL + rounding_mode = RoundingMode.HalfUp if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected: assert len(multiplier) == len(shift) == len(prev_op.bias.values) |