about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- ethosu/vela/high_level_command_to_npu_op.py 13
1 files changed, 9 insertions, 4 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 812e8e9a..61ce1c96 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -130,7 +130,7 @@ def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
return True
-def get_rounding_mode(op: Operation) -> NpuRoundingMode:
+def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
"""Specifies type of rounding to be used"""
rounding_mode = NpuRoundingMode.TFL
if op.type == Op.ResizeBilinear:
@@ -140,7 +140,12 @@ def get_rounding_mode(op: Operation) -> NpuRoundingMode:
and op.ifm.dtype == DataType.int16
):
rounding_mode = NpuRoundingMode.NATURAL
- elif op.type.is_avgpool_op() and op.memory_function == Op.ConcatSliceWrite and op.kernel.elements_wh() == 1:
+ elif (
+ not fused_quantize
+ and op.type.is_avgpool_op()
+ and op.memory_function == Op.ConcatSliceWrite
+ and op.kernel.elements_wh() == 1
+ ):
rounding_mode = NpuRoundingMode.NATURAL
rounding_mode = op.attrs.get("rounding_mode", rounding_mode)
return rounding_mode
@@ -353,14 +358,14 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit
if cmd.scale_tensor is not None:
npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
npu_op.activation = create_npu_activation(op)
- npu_op.rounding_mode = get_rounding_mode(op)
+ npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
+ npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])
if not op.type.is_elementwise_op():
npu_op.padding = create_padding(cmd, op)
npu_op.kernel = to_npu_kernel(op.kernel)
npu_op.ifm_upscale = get_upscale(op)
- npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
return npu_op