path: root/ethosu/vela/graph_optimiser.py
author     Patrik Gustavsson <patrik.gustavsson@arm.com>   2020-12-16 13:08:06 +0100
committer  Patrik Gustavsson <patrik.gustavsson@arm.com>   2020-12-21 07:34:05 +0100
commit     bf31d647dc5df47410ee577b12427ddf076d816b (patch)
tree       85ddd620916565aa8565d072b764ca4918b405a1 /ethosu/vela/graph_optimiser.py
parent     2349d429d926e258e9a61d34c7fd97660ab9fb98 (diff)
download   ethos-u-vela-bf31d647dc5df47410ee577b12427ddf076d816b.tar.gz
MLBEDSW-3645 4D class for op ifm/ofm shapes
Add 4D shape class for op Ifm/ofm shapes

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ic0a98da9d2f9d085605e39a9ab5a26bad6e702a3
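The diff only imports the new class; its implementation lives in ethosu/vela/shape4d.py, which is not part of this file's diff. A minimal illustrative sketch, assuming a namedtuple-style class, of the interface the hunks below rely on (construction from a shape list that may be shorter than four dimensions, named batch/height/width accessors, and get_dim()):

    # Illustrative sketch only -- the real class in ethosu/vela/shape4d.py may differ.
    from collections import namedtuple

    class Shape4D(namedtuple("Shape4D", "batch height width depth")):
        """Named 4D (NHWC) shape."""

        def __new__(cls, shape=None):
            # Pad shorter shapes with leading 1s, e.g. [h, w, c] -> [1, h, w, c]
            shape = list(shape) if shape is not None else []
            shape = [1] * (4 - len(shape)) + shape
            return super().__new__(cls, *shape)

        def get_dim(self, idx):
            # idx is an axis index into the padded 4D shape (0=N, 1=H, 2=W, 3=C)
            return self[idx]

    # Example mirroring the batched fully-connected rewrite below:
    ifm_shape = Shape4D([1, 2, 2, 8])
    assert ifm_shape.batch == 1 and ifm_shape.height == 2 and ifm_shape.width == 2
    assert ifm_shape.get_dim(3) == 8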
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--  ethosu/vela/graph_optimiser.py  42
1 files changed, 26 insertions, 16 deletions
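Many hunks below replace manual ifm_shapes/ofm_shapes bookkeeping with a call to op.set_ifm_ofm_shapes(). That helper is defined on Operation in ethosu/vela/operation.py and is not shown in this diff; a hypothetical sketch of what such a call is assumed to do (rebuild the per-op shape lists as Shape4D values from the tensors the op currently points at, using the Shape4D sketch above):

    # Hypothetical sketch -- the real Operation.set_ifm_ofm_shapes() handles
    # op-specific cases (splits, concats, reshapes, ...) that this ignores.
    def set_ifm_ofm_shapes(op):
        # Re-derive the cached 4D shapes so rewrites that swap an op's
        # inputs/outputs keep ifm_shapes/ofm_shapes consistent.
        op.ifm_shapes = [Shape4D(t.shape) for t in op.inputs if t is not None]
        op.ofm_shapes = [Shape4D(t.shape) for t in op.outputs if t is not None]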
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index fdb0fae0..1128a311 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -37,6 +37,7 @@ from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
+from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
@@ -82,6 +83,7 @@ def rewrite_concat(tens, arch, nng):
new_op.run_on_npu = True
tens.ops.append(new_op)
DebugDatabase.add_optimised(concat_op, new_op)
+ new_op.set_ifm_ofm_shapes()
assert tens.shape[axis] == offset
# If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -121,7 +123,8 @@ def rewrite_split(tens, arch, nng):
if out == tens:
break
axis_4D = axis + (4 - len(out.shape))
- offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
+
+ offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
# If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
if (offset_start[-1] % 16) != 0:
@@ -132,6 +135,7 @@ def rewrite_split(tens, arch, nng):
new_op.attrs["split_start"] = offset_start
new_op.run_on_npu = True
new_op.set_output_tensor(tens)
+ new_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(split_op, new_op)
return tens
@@ -189,6 +193,7 @@ def fixup_conv2d_backprop(op, arch, nng):
if op.type == Op.Conv2DBackpropInput:
# flip the inputs
op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
+ op.set_ifm_ofm_shapes()
op.type = Op.Conv2DBackpropInputSwitchedBias
# Update strides
@@ -216,8 +221,7 @@ def convert_resizebilinear_1x1_to_add(op):
# Set the add inputs
op.inputs[1] = op.inputs[0]
op.inputs[0] = tens
- op.ifm_shapes = []
- op.ofm_shapes = []
+ op.set_ifm_ofm_shapes()
return op
@@ -323,14 +327,14 @@ def convert_batched_fc_shape(op, arch, nng):
ofm = op.outputs[0]
# Check if the FC is 2D and first dimension indicates batching
# TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete
- if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1:
+ if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1:
n = ifm.shape[0]
batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
h, w = batching_split.get(n, (1, n))
prev_op = ifm.ops[0]
desired_shape = [1, h, w, ifm.shape[-1]]
- op.ifm_shapes[0] = desired_shape
+ op.ifm_shapes[0] = Shape4D(desired_shape)
if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
# There is a preceding Reshape
@@ -356,7 +360,7 @@ def convert_batched_fc_shape(op, arch, nng):
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
desired_shape = [1, h, w, ofm.shape[-1]]
- op.ofm_shapes[0] = desired_shape
+ op.ofm_shapes[0] = Shape4D(desired_shape)
if (
len(ofm.consumer_list) == 1
@@ -395,6 +399,7 @@ def fixup_pack_input(op, arch, nng):
reshape_op.attrs["new_shape"] = desired_shape
reshape_op.inputs = [inp, new_shape_tens]
reshape_op.set_output_tensor(reshape_out)
+ reshape_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, reshape_op)
op.inputs[idx] = reshape_out
@@ -413,6 +418,7 @@ def unfuse_activation_function(op, arch, nng):
act_op.set_output_tensor(out_tens)
act_op.add_input_tensor(intermediate_tens)
op.set_output_tensor(intermediate_tens)
+ act_op.set_ifm_ofm_shapes()
return op
@@ -457,7 +463,7 @@ def fixup_stridedslice_output(tens, arch, nng):
new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
for idx, out_tens in enumerate(op.outputs):
- op.ofm_shapes[idx] = new_shape_tens
+ op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape)
reshape_in = out_tens.clone("_reshaped")
reshape_in.set_all_shapes(reshape_input_shape)
reshape_in.ops = [op]
@@ -466,6 +472,7 @@ def fixup_stridedslice_output(tens, arch, nng):
reshape_op.attrs["new_shape"] = reshape_input_shape
reshape_op.inputs = [reshape_in, new_shape_tens]
reshape_op.set_output_tensor(out_tens)
+ reshape_op.set_ifm_ofm_shapes()
op.outputs[idx] = reshape_in
@@ -493,6 +500,7 @@ def fixup_unpack_output(tens, arch, nng):
reshape_op.attrs["new_shape"] = reshape_input_shape
reshape_op.inputs = [reshape_in, new_shape_tens]
reshape_op.set_output_tensor(out_tens)
+ reshape_op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, reshape_op)
op.outputs[idx] = reshape_in
@@ -588,7 +596,8 @@ def convert_conv_to_fc(op, arch, nng):
# caching/double buffering for the weights.
# (Weights dont need to be reloaded for convs when IFM H and W are 1)
if op.type == Op.Conv2DBias:
- _, h, w, _ = op.ifm_shapes[0]
+ h = op.ifm_shapes[0].height
+ w = op.ifm_shapes[0].width
kh, kw, _, _ = op.inputs[1].shape
if h == 1 and w == 1 and kh == 1 and kw == 1:
# Overwrite this op as a Fully Connected Op
@@ -616,9 +625,11 @@ def convert_conv_to_fc(op, arch, nng):
reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
reshape_op.set_output_tensor(orig_ofm_tensor)
+ reshape_op.set_ifm_ofm_shapes()
# Replace this ops OFM to point to the 2D tensor
op.outputs[0] = fc_ofm_tensor
+ op.set_ifm_ofm_shapes()
# Record optimisation in debug database
DebugDatabase.add_optimised(op, reshape_op)
DebugDatabase.add_optimised(op, op)
@@ -649,6 +660,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
relu_fused_op.add_input_tensor(ifm)
relu_fused_op.set_output_tensor(ofm)
+ relu_fused_op.set_ifm_ofm_shapes()
op = relu_fused_op
return op
@@ -668,8 +680,8 @@ def fixup_act_reorder(op, arch, nng):
act_op_out = act_op.inputs[0].clone("_acted")
act_op_out.quantization = op.outputs[0].quantization.clone()
act_op.set_output_tensor(act_op_out)
- act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1)
- act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1)
+ act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape)
+ act_op.ofm_shapes[0] = Shape4D(act_op_out.shape)
# Update the consumer list
act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -839,6 +851,7 @@ def convert_lrelu_to_mul_max(op, arch):
mul_alpha.add_input_tensor(alpha_tens)
fm_alpha = ofm.clone(op.name + "_alpha")
mul_alpha.set_output_tensor(fm_alpha)
+ mul_alpha.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, mul_alpha)
if check_quantized_tens_scaling_equal(ifm, ofm):
@@ -860,6 +873,7 @@ def convert_lrelu_to_mul_max(op, arch):
mul_identity.add_input_tensor(identity_tens)
fm_id = ofm.clone(op.name + "_id")
mul_identity.set_output_tensor(fm_id)
+ mul_identity.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, mul_identity)
# Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
@@ -890,7 +904,7 @@ def convert_to_lut(op, lut_values, lut_name):
quantization.zero_point = 0
tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
op.add_input_tensor(tens)
- op.ifm_shapes.append(full_shape(4, tens.shape, 1))
+ op.ifm_shapes.append(Shape4D(tens.shape))
# The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
# so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
@@ -1158,11 +1172,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# combined rewrite graph pass
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng,
- sg,
- arch,
- [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split],
- [set_ifm_ofm_op_shapes],
+ nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [],
)
if verbose_graph: