Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
 ethosu/vela/graph_optimiser.py | 38 ++++++++++++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e8218fcc..bea22a23 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1435,12 +1435,46 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
             # This attribute means a different scaling calculation is used in order to match reference
             op.low_precision_scaling = True
             weight_scale = h * w
+            # Set zero points to 0; they are instead accounted for via the bias term
             foq = ofmq.clone()
-            foq.zero_point -= int(np.round(ifmq.zero_point * ifmq.scale_f32 / foq.scale_f32))
-            op.forced_output_quantization = foq
+            foq.zero_point = 0
             fiq = ifmq.clone()
             fiq.zero_point = 0
             op.forced_input_quantization = fiq
+            bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
+            # If the bias term is outside the uint8 range, an Add op is needed to apply it.
+            if bias_term < 0 or bias_term > 255:
+                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
+                # The bias term has higher bitness (i32) than the input/output (u8).
+                # 16 bits is enough: since the bias is added to/subtracted from a u8 value,
+                # it can only effectively assume values in the range [-255, 255].
+                intermediate.dtype = DataType.int16
+                intermediate.quantization.zero_point = 0
+                add_op = Operation(Op.Add, op.name + "_bias")
+                add_op.forced_output_quantization = foq
+                add_op.add_input_tensor(intermediate)
+                quant = QuantizationParameters()
+                quant.zero_point = 0
+                bias_term_tens = create_const_tensor(
+                    op.name + "_bias",
+                    [1, 1, 1, 1],
+                    DataType.int16,
+                    [bias_term],
+                    np.int16,
+                    quantization=quant,
+                    quant_value_dtype=np.int16,
+                )
+                add_op.add_input_tensor(bias_term_tens)
+                add_op.set_output_tensor(op.ofm)
+                add_op.set_ifm_ofm_shapes()
+                add_op.activation = op.activation
+                op.activation = None
+                op.set_output_tensor(intermediate)
+                op.set_ifm_ofm_shapes()
+            # Otherwise the bias term fits in uint8 and can be applied via the OFM zero point.
+            else:
+                foq.zero_point = bias_term
+                op.forced_output_quantization = foq
         else:
             assert inp.dtype == DataType.int8
             # Use a depthwise to calculate the sum,
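
Editor's note on the bias-term arithmetic introduced in this hunk: the lowering folds the input and output zero points into one integer term, and only the case where that term falls outside the uint8 range needs the extra Add op. A minimal sketch of the decision, using made-up quantization values (mean_bias_term is an illustrative helper, not part of Vela's API):

    def mean_bias_term(ifm_zp, ifm_scale, ofm_zp, ofm_scale):
        # Same arithmetic as the patch: the OFM zero point minus the IFM zero
        # point rescaled into the OFM quantization domain, truncated by int().
        return ofm_zp - int(ifm_zp * ifm_scale / ofm_scale)

    # Matching scales and zero points: the term is 0, fits in uint8, and the
    # "else" branch simply writes it into the forced OFM zero point.
    assert mean_bias_term(ifm_zp=128, ifm_scale=0.5, ofm_zp=128, ofm_scale=0.5) == 0

    # Zero point only on the input side: the term goes negative, which is
    # outside [0, 255] and triggers the int16 Add-op path.
    assert mean_bias_term(ifm_zp=200, ifm_scale=1.0, ofm_zp=0, ofm_scale=1.0) == -200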
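On the choice of DataType.int16 for the intermediate tensor: a u8-range value plus a bias that is effectively bounded to [-255, 255] stays far inside the int16 range, which a quick numpy check (nothing Vela-specific) confirms:

    import numpy as np

    # Extremes of "u8 value plus effective bias": well within int16.
    lo, hi = 0 - 255, 255 + 255                # -255 and 510
    info = np.iinfo(np.int16)
    assert info.min <= lo and hi <= info.max   # -32768 <= -255, 510 <= 32767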
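The Add-op branch itself follows a standard graph-rewrite pattern: clone the OFM into a higher-precision intermediate, retarget the original op to it, and splice in an elementwise Add that inherits the fused activation and the forced output quantization. A self-contained sketch of that rewiring on toy objects (Node and rewire_with_bias_add are hypothetical stand-ins for Vela's Operation/Tensor machinery):

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class Node:
        name: str
        inputs: List[str] = field(default_factory=list)
        output: str = ""
        activation: Optional[str] = None

    def rewire_with_bias_add(op: Node, bias_const: str) -> Node:
        # Retarget the original op to a fresh intermediate tensor, then append
        # an elementwise Add that writes to the original OFM.
        intermediate = op.output + "_intermediate"
        add_op = Node(op.name + "_bias", inputs=[intermediate, bias_const], output=op.output)
        # Move the fused activation onto the Add so it runs exactly once,
        # after the bias is applied (mirrors "op.activation = None" in the patch).
        add_op.activation, op.activation = op.activation, None
        op.output = intermediate
        return add_op

    mean = Node("mean1", inputs=["ifm"], output="ofm", activation="RELU")
    add_op = rewire_with_bias_add(mean, "mean1_bias")
    assert mean.output == "ofm_intermediate" and mean.activation is None
    assert add_op.output == "ofm" and add_op.activation == "RELU"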