aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-01-29 11:51:31 +0100
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-01-29 16:05:03 +0100
commit2c2522dd44229a03d3d778cd239478fedc19ee57 (patch)
tree610bd611f9783f71cf79f4c2e8466789cacfd429 /ethosu/vela/graph_optimiser.py
parent7bada4039d01836c995a12251034777055e1848a (diff)
downloadethos-u-vela-2c2522dd44229a03d3d778cd239478fedc19ee57.tar.gz
MLBEDSW-3772 Fix FC with changed inp shape
When FC input is fixed by changing ifm_shape, avoid_NHCWB16 must be set to ifm. -Fixed issue with ResizeBilinear -Changed to post order for concat ops in graph optimisation Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: Ie0c6a86637c210c0833ae9b2f8e7c494c5d4f66e
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--ethosu/vela/graph_optimiser.py26
1 files changed, 20 insertions, 6 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index f1b2d35c..ab4d916e 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -59,7 +59,7 @@ def remove_passthrough_tensor(tens, arch, nng):
return tens
-def rewrite_concat_ops(op, arch, nng):
+def rewrite_concat_ops(op, arch):
if not op.run_on_npu or not op.type.is_concat_op():
return op
@@ -283,8 +283,8 @@ def convert_resizebilinear_to_2x2_pool(op):
op.attrs["padding"] = Padding.SAME
op.inputs[0].resampling_mode = resampling_mode.NEAREST
- upscaled_shape = op.ifm_shape[0].get_hw_as_list()
- out_shape = op.ofm_shape[0].get_hw_as_list()
+ upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
+ out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
return op
@@ -346,6 +346,20 @@ def convert_nop_split_to_identity(op, arch, nng):
return op
+def rewrite_fully_connected_input(op, arch, nng):
+ if op.type == Op.FullyConnected:
+ n_in_elems = op.weights.shape[-2]
+ elms = op.ifm.elements()
+ batch_size = elms // n_in_elems
+ assert batch_size * n_in_elems == elms
+
+ if op.ifm.shape != [batch_size, n_in_elems]:
+ op.ifm.avoid_NHCWB16 = True
+
+ op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
+ return op
+
+
def convert_batched_fc_shape(op, arch, nng):
if op.type == Op.FullyConnected:
# Check if the first dimension indicates batching
@@ -1199,9 +1213,8 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
# Handle Concat Ops
for idx, sg in enumerate(nng.subgraphs):
# rewrite graph pass
- nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [rewrite_concat_ops], rewrite_unsupported=False,
- )
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
+ sg.refresh_after_modification()
# Handle Split Ops
for idx, sg in enumerate(nng.subgraphs):
@@ -1232,6 +1245,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
convert_conv_to_fc,
convert_softmax,
optimise_strided_conv,
+ rewrite_fully_connected_input,
convert_batched_fc_shape,
fixup_conv2d_backprop,
fixup_relus_with_differing_ifm_ofm_scaling,