aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2020-12-01 16:02:29 +0100
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2020-12-18 16:33:32 +0100
commit2349d429d926e258e9a61d34c7fd97660ab9fb98 (patch)
treeb5151d0f12428e47d64b1fb2ce4f2f8c19304a0d /ethosu/vela/graph_optimiser.py
parent528a56df829b65f7a2c61953650b123c461095f7 (diff)
downloadethos-u-vela-2349d429d926e258e9a61d34c7fd97660ab9fb98.tar.gz
MLBEDSW-3654 Add/use op ifm/ofm shapes
Add ifm/ofm shapes to op Changed to rely on these shapes Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--ethosu/vela/graph_optimiser.py69
1 files changed, 53 insertions, 16 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 4806001f..fdb0fae0 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -75,7 +75,7 @@ def rewrite_concat(tens, arch, nng):
new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
new_op.inputs = [inp]
new_op.outputs = [tens]
- new_op.attrs["concat_axis"] = axis
+ new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape))
new_op.attrs["concat_start"] = offset
offset += inp.shape[axis]
new_op.attrs["concat_end"] = offset
@@ -116,21 +116,20 @@ def rewrite_split(tens, arch, nng):
# be calculated from the index of the output tensor
if axis is not None:
# Get the start and end of the split
- offset_start = [0] * len(tens.shape)
- offset_end = [0] * len(tens.shape)
- for out in outputs:
+ offset_start = [0] * 4
+ for idx, out in enumerate(outputs):
if out == tens:
break
- offset_start[axis] += out.shape[axis]
+ axis_4D = axis + (4 - len(out.shape))
+ offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
# If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
if (offset_start[-1] % 16) != 0:
inp.avoid_NHCWB16 = True
-
- offset_end[axis] = offset_start[axis] + tens.shape[axis]
+ else:
+ offset_start = full_shape(4, offset_start, 0)
new_op.attrs["split_start"] = offset_start
- new_op.attrs["split_end"] = offset_end
new_op.run_on_npu = True
new_op.set_output_tensor(tens)
DebugDatabase.add_optimised(split_op, new_op)
@@ -217,6 +216,8 @@ def convert_resizebilinear_1x1_to_add(op):
# Set the add inputs
op.inputs[1] = op.inputs[0]
op.inputs[0] = tens
+ op.ifm_shapes = []
+ op.ofm_shapes = []
return op
@@ -321,13 +322,16 @@ def convert_batched_fc_shape(op, arch, nng):
ifm = op.inputs[0]
ofm = op.outputs[0]
# Check if the FC is 2D and first dimension indicates batching
- if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1:
+ # TODO: op.ifm_shapes[0][0] > 1 is enough when the refactoring is complete
+ if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1:
n = ifm.shape[0]
batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
h, w = batching_split.get(n, (1, n))
prev_op = ifm.ops[0]
desired_shape = [1, h, w, ifm.shape[-1]]
+ op.ifm_shapes[0] = desired_shape
+
if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
# There is a preceding Reshape
# Compare input of prev_op and input of op, to see if prev_op can be removed
@@ -352,6 +356,8 @@ def convert_batched_fc_shape(op, arch, nng):
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
desired_shape = [1, h, w, ofm.shape[-1]]
+ op.ofm_shapes[0] = desired_shape
+
if (
len(ofm.consumer_list) == 1
and ofm.consumer_list[0] is not None
@@ -451,6 +457,7 @@ def fixup_stridedslice_output(tens, arch, nng):
new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
for idx, out_tens in enumerate(op.outputs):
+ op.ofm_shapes[idx] = new_shape_tens
reshape_in = out_tens.clone("_reshaped")
reshape_in.set_all_shapes(reshape_input_shape)
reshape_in.ops = [op]
@@ -489,7 +496,6 @@ def fixup_unpack_output(tens, arch, nng):
DebugDatabase.add_optimised(op, reshape_op)
op.outputs[idx] = reshape_in
-
return tens
@@ -582,7 +588,7 @@ def convert_conv_to_fc(op, arch, nng):
# caching/double buffering for the weights.
# (Weights dont need to be reloaded for convs when IFM H and W are 1)
if op.type == Op.Conv2DBias:
- _, h, w, _ = op.inputs[0].shape
+ _, h, w, _ = op.ifm_shapes[0]
kh, kw, _, _ = op.inputs[1].shape
if h == 1 and w == 1 and kh == 1 and kw == 1:
# Overwrite this op as a Fully Connected Op
@@ -595,6 +601,7 @@ def convert_conv_to_fc(op, arch, nng):
weight_tensor = op.inputs[1]
weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+
# The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it
# back to 4D afterwards as the next layer is expecting that shape
orig_ofm_tensor = op.outputs[0]
@@ -609,6 +616,7 @@ def convert_conv_to_fc(op, arch, nng):
reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
reshape_op.set_output_tensor(orig_ofm_tensor)
+
# Replace this ops OFM to point to the 2D tensor
op.outputs[0] = fc_ofm_tensor
# Record optimisation in debug database
@@ -651,6 +659,8 @@ def fixup_act_reorder(op, arch, nng):
prep_op = get_prepend_op(op)
if prep_op is not None:
act_op = op.clone("_reordered")
+ act_op.ifm_shapes = list(op.ifm_shapes)
+ act_op.ofm_shapes = list(op.ofm_shapes)
# There is only one input tensor, overwrite it
act_op.set_input_tensor(prep_op.inputs[0], 0)
@@ -658,6 +668,8 @@ def fixup_act_reorder(op, arch, nng):
act_op_out = act_op.inputs[0].clone("_acted")
act_op_out.quantization = op.outputs[0].quantization.clone()
act_op.set_output_tensor(act_op_out)
+ act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1)
+ act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1)
# Update the consumer list
act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -704,6 +716,15 @@ def set_tensor_equivalence(op, arch, nng):
return op
+def set_ifm_ofm_op_shapes(op, arch, nng):
+ if op.run_on_npu and op.type.needs_shapes():
+ if op.ifm_shapes or op.ofm_shapes:
+ # Shapes already set
+ return op
+ op.set_ifm_ofm_shapes()
+ return op
+
+
def convert_softmax(op, arch, nng):
if op.type == Op.Softmax and op.run_on_npu:
softmax = SoftMax(op)
@@ -839,7 +860,7 @@ def convert_lrelu_to_mul_max(op, arch):
mul_identity.add_input_tensor(identity_tens)
fm_id = ofm.clone(op.name + "_id")
mul_identity.set_output_tensor(fm_id)
- DebugDatabase.add_optimised(op, mul_alpha)
+ DebugDatabase.add_optimised(op, mul_identity)
# Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
op.type = Op.Maximum
@@ -869,6 +890,8 @@ def convert_to_lut(op, lut_values, lut_name):
quantization.zero_point = 0
tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
op.add_input_tensor(tens)
+ op.ifm_shapes.append(full_shape(4, tens.shape, 1))
+
# The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
# so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
# should be the same as the IFM
@@ -1072,10 +1095,20 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
if verbose_graph:
nng.print_graph()
+ pre_process_list = [
+ supported_operator_check,
+ set_ifm_ofm_op_shapes,
+ # TODO: memory-only Op removal
+ ]
+
+ for idx, sg in enumerate(nng.subgraphs):
+ # rewrite graph pass
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+ )
+
op_rewrite_list = [
set_tensor_equivalence,
- supported_operator_check,
- # then do any rewrites of supported operators
convert_depthwise_to_conv,
convert_conv_to_fc,
convert_softmax,
@@ -1106,7 +1139,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# remove passthrough tensors and attempt further optimizations
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
+ nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields],
)
# Post-optimisation operator debug tracing
@@ -1125,7 +1158,11 @@ def optimise_graph_b(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# combined rewrite graph pass
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], []
+ nng,
+ sg,
+ arch,
+ [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split],
+ [set_ifm_ofm_op_shapes],
)
if verbose_graph: