diff options
Diffstat (limited to 'ethosu/vela/test/test_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/test/test_graph_optimiser.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py index b01b07c3..55980e3d 100644 --- a/ethosu/vela/test/test_graph_optimiser.py +++ b/ethosu/vela/test/test_graph_optimiser.py @@ -22,6 +22,7 @@ from ethosu.vela.data_type import DataType from ethosu.vela.graph_optimiser import convert_batched_fc_shape from ethosu.vela.graph_optimiser import optimise_graph_a from ethosu.vela.graph_optimiser import optimise_pad +from ethosu.vela.graph_optimiser import rewrite_fully_connected_input from ethosu.vela.nn_graph import Graph from ethosu.vela.operation import Op from ethosu.vela.operation import Padding @@ -47,8 +48,8 @@ def test_convert_batched_fc(): prev_op.ifm_shapes = op.ifm_shapes.copy() prev_op.ofm_shapes = op.ofm_shapes.copy() + rewrite_fully_connected_input(op, None, None) conv_op = convert_batched_fc_shape(op, None, None) - assert conv_op.ifm == prev_op.ifm assert conv_op.ofm == prev_op.ofm assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8]) @@ -68,6 +69,7 @@ def test_convert_batched_fc(): prev_op.ifm_shapes = op.ifm_shapes.copy() prev_op.ofm_shapes = op.ofm_shapes.copy() + rewrite_fully_connected_input(op, None, None) conv_op = convert_batched_fc_shape(op, None, None) assert conv_op.ifm == prev_op.ifm |