diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 46d26c80..aaccce2c 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -832,14 +832,14 @@ def add_attrs_to_resizebilinear(op, arch): return op -def add_bias_tensor(op, arch): - if ("conv2d" in op.type.lower() or op.type.startswith("FullyConnected")) and not op.inputs[-1]: - # Add bias/scale tensor filled with zeros - weight_shape = op.inputs[1].shape - weight_sets = weight_shape[-1] - bias_values = [0] * weight_sets - scale_tens = create_const_tensor(op.name + "_bias", [weight_sets], DataType.int32, bias_values) - op.set_input_tensor(scale_tens, -1) +def fixup_bias_tensors(op, arch): + if op.needs_bias() and not op.inputs[-1]: + # Op has no bias, add bias tensor filled with zeros + nr_biases = op.inputs[1].shape[-1] + bias_values = [0] * nr_biases + bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values) + bias_tensor.quant_values = bias_tensor.values + op.set_input_tensor(bias_tensor, -1) return op @@ -870,7 +870,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): fixup_elementwise_with_scalars, reorder_depthwise_weights, fixup_resizebilinear, - add_bias_tensor, + fixup_bias_tensors, convert_mul_max_to_abs_or_lrelu, convert_lrelu, ] |