Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 38 |
1 file changed, 36 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e8218fcc..bea22a23 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1435,12 +1435,46 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
             # This attribute means a different scaling calculation is used in order to match reference
             op.low_precision_scaling = True
             weight_scale = h * w
+            # Set zero points to 0 as they will be adjusted for with bias term
             foq = ofmq.clone()
-            foq.zero_point -= int(np.round(ifmq.zero_point * ifmq.scale_f32 / foq.scale_f32))
-            op.forced_output_quantization = foq
+            foq.zero_point = 0
             fiq = ifmq.clone()
             fiq.zero_point = 0
             op.forced_input_quantization = fiq
+            bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
+            # If the bias term is outside uint8 range, we need an Add op to apply it.
+            if bias_term < 0 or bias_term > 255:
+                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
+                # Bias term has higher bitness (i32) than input/output (u8).
+                # 16 bits is enough since the bias is added/subtracted from a u8 value,
+                # the bias can only effectively assume values in the range [-255, 255].
+                intermediate.dtype = DataType.int16
+                intermediate.quantization.zero_point = 0
+                add_op = Operation(Op.Add, op.name + "_bias")
+                add_op.forced_output_quantization = foq
+                add_op.add_input_tensor(intermediate)
+                quant = QuantizationParameters()
+                quant.zero_point = 0
+                bias_term_tens = create_const_tensor(
+                    op.name + "_bias",
+                    [1, 1, 1, 1],
+                    DataType.int16,
+                    [bias_term],
+                    np.int16,
+                    quantization=quant,
+                    quant_value_dtype=np.int16,
+                )
+                add_op.add_input_tensor(bias_term_tens)
+                add_op.set_output_tensor(op.ofm)
+                add_op.set_ifm_ofm_shapes()
+                add_op.activation = op.activation
+                op.activation = None
+                op.set_output_tensor(intermediate)
+                op.set_ifm_ofm_shapes()
+            # If not, we can just do it with the OFM zero point.
+            else:
+                foq.zero_point = bias_term
+                op.forced_output_quantization = foq
         else:
             assert inp.dtype == DataType.int8
             # Use a depthwise to calculate the sum,
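
For readers following the change: the patch replaces the old single adjustment of the output zero point (foq.zero_point -= ...) with an explicit bias_term that folds both the IFM and OFM zero points into one constant. When that constant fits in the uint8 zero-point range it is applied through the forced OFM quantization; otherwise the mean's output is routed through an int16 intermediate tensor and a trailing Add op applies it. Below is a minimal sketch of just that arithmetic, runnable outside Vela; QuantParams, mean_bias_term and needs_add_op are hypothetical helpers for illustration, and the quantization values are made up:

    from dataclasses import dataclass

    @dataclass
    class QuantParams:  # hypothetical stand-in for Vela's QuantizationParameters
        scale_f32: float
        zero_point: int

    def mean_bias_term(ifmq: QuantParams, ofmq: QuantParams) -> int:
        # Same formula as the patch: fold both zero points into one constant.
        return ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)

    def needs_add_op(bias_term: int) -> bool:
        # Outside [0, 255] a uint8 zero point cannot carry the correction,
        # so the patch inserts an int16 Add op after the depthwise conv.
        return bias_term < 0 or bias_term > 255

    # Fits in uint8: applied via the forced OFM zero point (the else branch).
    print(mean_bias_term(QuantParams(0.5, 10), QuantParams(0.25, 128)))  # 108
    print(needs_add_op(108))                                             # False

    # Negative: only representable via the extra Add op (the if branch).
    print(mean_bias_term(QuantParams(1.0, 200), QuantParams(1.0, 0)))    # -200
    print(needs_add_op(-200))                                            # True

This also motivates the int16 choice seen in the patch: the correction is added to or subtracted from a u8 value, so per the patch's own comment it can only effectively assume values in [-255, 255], which comfortably fits in 16 bits even though the raw bias term is conceptually i32.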