aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_graph_optimiser.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-01-29 11:51:31 +0100
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-01-29 16:05:03 +0100
commit2c2522dd44229a03d3d778cd239478fedc19ee57 (patch)
tree610bd611f9783f71cf79f4c2e8466789cacfd429 /ethosu/vela/test/test_graph_optimiser.py
parent7bada4039d01836c995a12251034777055e1848a (diff)
downloadethos-u-vela-2c2522dd44229a03d3d778cd239478fedc19ee57.tar.gz
MLBEDSW-3772 Fix FC with changed inp shape
When FC input is fixed by changing ifm_shape, avoid_NHCWB16 must be set to ifm. -Fixed issue with ResizeBilinear -Changed to post order for concat ops in graph optimisation Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: Ie0c6a86637c210c0833ae9b2f8e7c494c5d4f66e
Diffstat (limited to 'ethosu/vela/test/test_graph_optimiser.py')
-rw-r--r--ethosu/vela/test/test_graph_optimiser.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b01b07c3..55980e3d 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -22,6 +22,7 @@ from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import convert_batched_fc_shape
from ethosu.vela.graph_optimiser import optimise_graph_a
from ethosu.vela.graph_optimiser import optimise_pad
+from ethosu.vela.graph_optimiser import rewrite_fully_connected_input
from ethosu.vela.nn_graph import Graph
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
@@ -47,8 +48,8 @@ def test_convert_batched_fc():
prev_op.ifm_shapes = op.ifm_shapes.copy()
prev_op.ofm_shapes = op.ofm_shapes.copy()
+ rewrite_fully_connected_input(op, None, None)
conv_op = convert_batched_fc_shape(op, None, None)
-
assert conv_op.ifm == prev_op.ifm
assert conv_op.ofm == prev_op.ofm
assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
@@ -68,6 +69,7 @@ def test_convert_batched_fc():
prev_op.ifm_shapes = op.ifm_shapes.copy()
prev_op.ofm_shapes = op.ofm_shapes.copy()
+ rewrite_fully_connected_input(op, None, None)
conv_op = convert_batched_fc_shape(op, None, None)
assert conv_op.ifm == prev_op.ifm