aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser.py
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-08-26 18:21:28 +0200
committertim.hall <tim.hall@arm.com>2020-08-27 14:23:03 +0000
commita41cd4de2af1e43b76a2a33d78eeb2d90a88b757 (patch)
treee7f81ab5fbddca95928e2111fea2f6cff9b75679 /ethosu/vela/graph_optimiser.py
parent2abd3dd75bd3d20e1a3aeaf12362f9872b40fa0a (diff)
downloadethos-u-vela-a41cd4de2af1e43b76a2a33d78eeb2d90a88b757.tar.gz
Small fix for Softmax regression
Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: I287c24725126c169afec779b921e43c3ab26f739
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--ethosu/vela/graph_optimiser.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 46d26c80..aaccce2c 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -832,14 +832,14 @@ def add_attrs_to_resizebilinear(op, arch):
return op
-def add_bias_tensor(op, arch):
- if ("conv2d" in op.type.lower() or op.type.startswith("FullyConnected")) and not op.inputs[-1]:
- # Add bias/scale tensor filled with zeros
- weight_shape = op.inputs[1].shape
- weight_sets = weight_shape[-1]
- bias_values = [0] * weight_sets
- scale_tens = create_const_tensor(op.name + "_bias", [weight_sets], DataType.int32, bias_values)
- op.set_input_tensor(scale_tens, -1)
+def fixup_bias_tensors(op, arch):
+ if op.needs_bias() and not op.inputs[-1]:
+ # Op has no bias, add bias tensor filled with zeros
+ nr_biases = op.inputs[1].shape[-1]
+ bias_values = [0] * nr_biases
+ bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
+ bias_tensor.quant_values = bias_tensor.values
+ op.set_input_tensor(bias_tensor, -1)
return op
@@ -870,7 +870,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
fixup_elementwise_with_scalars,
reorder_depthwise_weights,
fixup_resizebilinear,
- add_bias_tensor,
+ fixup_bias_tensors,
convert_mul_max_to_abs_or_lrelu,
convert_lrelu,
]