Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--  ethosu/vela/graph_optimiser.py  18
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 46d26c80..aaccce2c 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -832,14 +832,14 @@ def add_attrs_to_resizebilinear(op, arch):
     return op

-def add_bias_tensor(op, arch):
-    if ("conv2d" in op.type.lower() or op.type.startswith("FullyConnected")) and not op.inputs[-1]:
-        # Add bias/scale tensor filled with zeros
-        weight_shape = op.inputs[1].shape
-        weight_sets = weight_shape[-1]
-        bias_values = [0] * weight_sets
-        scale_tens = create_const_tensor(op.name + "_bias", [weight_sets], DataType.int32, bias_values)
-        op.set_input_tensor(scale_tens, -1)
+def fixup_bias_tensors(op, arch):
+    if op.needs_bias() and not op.inputs[-1]:
+        # Op has no bias, add bias tensor filled with zeros
+        nr_biases = op.inputs[1].shape[-1]
+        bias_values = [0] * nr_biases
+        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
+        bias_tensor.quant_values = bias_tensor.values
+        op.set_input_tensor(bias_tensor, -1)
     return op
@@ -870,7 +870,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
         fixup_elementwise_with_scalars,
         reorder_depthwise_weights,
         fixup_resizebilinear,
-        add_bias_tensor,
+        fixup_bias_tensors,
         convert_mul_max_to_abs_or_lrelu,
         convert_lrelu,
     ]
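
The rename from add_bias_tensor to fixup_bias_tensors also swaps the hard-coded operator type check for an op.needs_bias() query, and the zero-filled bias constant now has its quant_values populated from values, presumably so that later passes reading the tensor's quantised data see the same zeros. As a rough sketch only, assuming a simplified Operation class (names and layout are illustrative, not the actual Vela source), the predicate would cover the same cases the removed inline condition did:

# Illustrative sketch, not the actual Vela implementation: a needs_bias()
# predicate equivalent to the inline check removed from add_bias_tensor.
class Operation:
    def __init__(self, op_type, name):
        self.type = op_type   # operator type string, e.g. "Conv2DBiasAct"
        self.name = name
        self.inputs = []      # ifm, weights, bias (bias entry may be empty)

    def needs_bias(self):
        # Convolution-like and fully connected ops expect a bias/scale input
        return "conv2d" in self.type.lower() or self.type.startswith("FullyConnected")

With a predicate like that, fixup_bias_tensors only has to ask the op whether a bias is required, so further bias-carrying operator types can be handled by extending needs_bias() rather than the graph optimiser pass itself.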