aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_graph_optimiser.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2021-03-11 14:59:06 +0100
committerLouis Verhaard <louis.verhaard@arm.com>2021-03-16 10:37:20 +0100
commitc822d62ba27b874a130e9d8d434c12b419d10d62 (patch)
tree4450a032513b537ffd3545cb2bdb6e052339beb2 /ethosu/vela/test/test_graph_optimiser.py
parent8ba0792731d47de64a59d93359340f3c88fc4a62 (diff)
downloadethos-u-vela-c822d62ba27b874a130e9d8d434c12b419d10d62.tar.gz
MLBEDSW-4223: Full support for PAD operator
- Added full support for PAD operator - Hardware padding is still used whenever possible - Bug fix Pad followed by max pool if IFM contains negative values Change-Id: Ifc64d1943737d94466f5e2821009dab12a49a965 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/test/test_graph_optimiser.py')
-rw-r--r--ethosu/vela/test/test_graph_optimiser.py194
1 files changed, 157 insertions, 37 deletions
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 285b3ac5..d9e171d6 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -23,7 +23,7 @@ from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import calc_explicit_padding
from ethosu.vela.graph_optimiser import convert_batched_fc_shape
from ethosu.vela.graph_optimiser import optimise_graph_a
-from ethosu.vela.graph_optimiser import optimise_pad
+from ethosu.vela.graph_optimiser import replace_pad_by_hw_pad
from ethosu.vela.graph_optimiser import rewrite_fully_connected_input
from ethosu.vela.nn_graph import Graph
from ethosu.vela.operation import Op
@@ -116,47 +116,92 @@ def test_calc_explicit_padding(test_input, expected_result):
assert (before, after) == expected_result
-def test_optimise_pad():
+def create_pad_and_conv2d(
+ in_shape,
+ out_shape,
+ padding,
+ in_dtype=DataType.int8,
+ out_dtype=DataType.int8,
+ pad_dtype=DataType.int32,
+ pad_setting=Padding.VALID,
+ kernel_size=3,
+):
+ """Creates Pad operator followed by a conv2d operator"""
+ qp = testutil.default_quant_params()
+ in0 = Tensor(in_shape, in_dtype, "in")
+ in0.quantization = qp
+ pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ out = Tensor(out_shape, out_dtype, "out")
+ out.quantization = qp.clone()
+ op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+ op.run_on_npu = True
+ conv_out_tens = Tensor(in_shape, in_dtype, "output")
+ conv_out_tens.quantization = qp.clone()
+ weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
+ weight_tens.values = np.zeros(weight_tens.shape)
+ weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
+ weight_tens.quantization = qp.clone()
+ bias_tens = Tensor(out_shape, pad_dtype, "biases")
+ attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
+ attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+ conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
+ conv2d_op.add_input_tensor(out)
+ conv2d_op.run_on_npu = True
+ return op, conv2d_op
+
+
+def test_pad_followed_by_conv_is_removed():
"""
Tests that the PAD operator is bypassed when followed by a convolution operator,
and that the padding of the convolution operation is correctly updated
"""
- # Create Pad operation followed by Conv2D
- quant = testutil.default_quant_params()
- in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
- in_tens.quantization = quant
- pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
- temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
- temp_tens.quantization = quant.clone()
- out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
- out_tens.quantization = quant.clone()
- weight_tens = Tensor([5, 3, 64, 64], DataType.uint8, "weights")
- weight_tens.values = np.zeros(weight_tens.shape)
- weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
- weight_tens.quantization = quant.clone()
-
- bias_tens = Tensor([64], DataType.int32, "biases")
- pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
- attrs = {"padding": Padding.VALID, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
- attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
- pad_op.run_on_npu = True
- conv2d_op = testutil.create_op(Op.Conv2D, [temp_tens, weight_tens, bias_tens], out_tens, attrs)
- conv2d_op.run_on_npu = True
- nng = Graph()
- sg = testutil.create_subgraph([pad_op, conv2d_op])
- nng.subgraphs.append(sg)
+ pad_op, conv2d_op = create_pad_and_conv2d(
+ in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
+ )
+ nng = testutil.create_graph([pad_op, conv2d_op])
arch = testutil.create_arch()
- optimise_pad(conv2d_op, nng, arch)
+ replace_pad_by_hw_pad(conv2d_op, nng, arch)
- op = sg.output_tensors[0].ops[0]
- assert op.type == Op.Conv2D
+ op = nng.subgraphs[0].output_tensors[0].ops[0]
+ assert op.type == Op.Conv2DBias
assert op.attrs["padding"] == Padding.EXPLICIT
assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
assert op.ifm.shape == [1, 76, 75, 64]
assert pad_op not in op.ifm.ops
+leading_pad_test_data = [
+ (2, 2, 11, True),
+ (1, 2, 11, False),
+ (2, 1, 11, False),
+ (5, 2, 11, True),
+]
+
+
+@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
+def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
+ # Tests PAD operator with big kernel size; top and left pad must be multiple of stride
+ out_shape = [1, 11 + left, 11 + top, 1]
+ padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
+ pad_op, conv2d_op = create_pad_and_conv2d(
+ in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
+ )
+ nng = testutil.create_graph([pad_op, conv2d_op])
+ arch = testutil.create_arch()
+ replace_pad_by_hw_pad(conv2d_op, nng, arch)
+ op = nng.subgraphs[0].output_tensors[0].ops[0]
+ if expect_pad_removed:
+ assert op.attrs["padding"] == Padding.EXPLICIT
+ assert "explicit_padding" in op.attrs
+ assert op.ifm.shape == op.ofm.shape
+ assert pad_op not in op.ifm.ops
+ else:
+ assert pad_op in op.ifm.ops
+ assert op.attrs["padding"] == Padding.VALID
+ assert "explicit_padding" not in op.attrs
+
+
def test_optimise_pad_followed_by_avg_pool():
"""
Tests that the PAD operator is bypassed when followed by a average pool operator,
@@ -166,7 +211,8 @@ def test_optimise_pad_followed_by_avg_pool():
quant = testutil.default_quant_params()
in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
in_tens.quantization = quant
- pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
+ # Test with 3x2 input tensor
+ pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
temp_tens.quantization = quant.clone()
out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
@@ -185,25 +231,99 @@ def test_optimise_pad_followed_by_avg_pool():
pad_op.run_on_npu = True
conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
conv2d_op.run_on_npu = True
- nng = Graph()
- sg = testutil.create_subgraph([pad_op, conv2d_op])
- nng.subgraphs.append(sg)
+ nng = testutil.create_graph([pad_op, conv2d_op])
arch = testutil.create_arch()
- optimise_pad(conv2d_op, nng, arch)
+ replace_pad_by_hw_pad(conv2d_op, nng, arch)
- op = sg.output_tensors[0].ops[0]
+ op = nng.subgraphs[0].output_tensors[0].ops[0]
assert op.type == Op.DepthwiseConv2DBias
assert op.attrs["padding"] == Padding.EXPLICIT
- assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
+ assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
assert op.ifm.shape == [1, 76, 75, 64]
assert pad_op not in op.ifm.ops
# Check that bias and weight tensors have been added
assert op.bias.shape == [64]
- print("op.weights:", op.weights)
assert op.weights.shape == [5, 3, 1, 64]
+pad_avg_pool_test_data = [
+ ((3, 3), (1, 1, 1, 1), True),
+ ((3, 3), (2, 1, 1, 1), False),
+ ((3, 3), (1, 2, 1, 1), False),
+ ((3, 3), (1, 1, 2, 1), False),
+ ((3, 3), (1, 1, 1, 2), False),
+ ((2, 4), (1, 2, 1, 2), True),
+ ((5, 3), (2, 1, 2, 1), True),
+ ((5, 3), (0, 1, 2, 1), True),
+ ((5, 3), (2, 0, 2, 1), True),
+ ((5, 3), (2, 1, 0, 1), True),
+ ((5, 3), (2, 1, 0, 1), True),
+ ((4, 4), (2, 2, 2, 2), True),
+ ((4, 4), (1, 2, 2, 2), False),
+ ((4, 4), (2, 1, 2, 2), False),
+ ((4, 4), (2, 2, 1, 2), False),
+ ((4, 4), (2, 2, 2, 1), False),
+]
+
+
+@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
+def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
+ # Tests PAD followed by AvgPool
+ k_w, k_h = k_size
+ top, left, bottom, right = padding
+ pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
+ dtype = DataType.int8
+ qp = testutil.default_quant_params()
+ in_shape = [1, 15, 17, 8]
+ out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
+ in0 = Tensor(in_shape, dtype, "in")
+ in0.quantization = qp
+ pad_tensor = create_const_tensor(
+ name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
+ )
+ out = Tensor(out_shape, dtype, "out")
+ out.quantization = qp.clone()
+ pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+ pool_out_tens = Tensor(in_shape, dtype, "output")
+ pool_out_tens.quantization = qp.clone()
+ attrs = {
+ "padding": Padding.VALID,
+ "ksize": [1, k_w, k_h, 1],
+ "stride_w": 1,
+ "stride_h": 1,
+ "dilation_w_factor": 1,
+ "dilation_h_factor": 1,
+ }
+ pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
+ pool_op.add_input_tensor(out)
+ pad_op.run_on_npu = True
+ pool_op.run_on_npu = True
+ nng = testutil.create_graph([pad_op, pool_op])
+ arch = testutil.create_arch()
+ nng = optimise_graph_a(nng, arch)
+ sg = nng.subgraphs[0]
+ all_ops = sg.get_all_ops()
+ print("all_ops: ", all_ops)
+ # Pad should not be in the graph anymore, it should either have been removed or rewritten
+ assert not any(op.type == Op.Pad for op in all_ops)
+ op = nng.subgraphs[0].output_tensors[0].ops[0]
+ if expect_pad_removed:
+ # Expect rewrite to depthwise, PAD is removed
+ assert op.type == Op.DepthwiseConv2DBias
+ assert op.attrs["padding"] == Padding.EXPLICIT
+ assert any(pad > 0 for pad in op.attrs["explicit_padding"])
+ assert op.ifm.shape == op.ofm.shape
+ # Check that bias and weight tensors have been added
+ assert len(op.bias.shape) > 0
+ assert op.weights.shape is not None
+ else:
+ # Pad should have been rewritten to a number of average pool operations
+ assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
+ assert pool_op.type == Op.AvgPool
+ assert pool_op.attrs["padding"] == Padding.VALID
+
+
def test_remove_reshape():
"""
Tests that the expected reshape are removed in graph_optimisation