Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--  ethosu/vela/graph_optimiser.py  580
1 file changed, 279 insertions(+), 301 deletions(-)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 5f111786..bb5a9e03 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -28,6 +28,7 @@ from . import scaling
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
+from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
@@ -42,7 +43,6 @@ from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
-from .tensor import create_reshape_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tflite_mapping import optype_to_builtintype
@@ -59,52 +59,68 @@ def remove_passthrough_tensor(tens, arch, nng):
return tens
-def rewrite_concat(tens, arch, nng):
- if len(tens.ops) == 1 and tens.ops[0].type.is_concat_op():
- concat_op = tens.ops[0]
- if tens != concat_op.outputs[0]:
- return tens # don't attempt to rewrite the min/max outputs of QuantizedConcat
+def rewrite_concat_ops(op, arch, nng):
+ if not op.run_on_npu or not op.type.is_concat_op():
+ return op
- # Not supported so leave it and run on CPU
- if not concat_op.run_on_npu:
- return tens
+ axis_4D = 0
+ ofm = op.ofm
+ ofm.ops = []
+ offset = 0
- inputs, axis = concat_op.get_concat_inputs_axis()
+ if op.type == Op.Pack:
+ # Pack is also referred to as Stack
+ axis = int(op.attrs["axis"])
+ desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
- tens.ops = []
- offset = 0
- for idx, inp in enumerate(inputs):
+ if axis >= 0:
+ axis_4D = axis + (4 - len(desired_shape))
+ else:
+ axis_4D = axis
+
+ for idx, inp in enumerate(op.inputs):
+ op.ifm_shapes[idx] = Shape4D(desired_shape)
+ if Shape4D(inp.shape) != op.ifm_shapes[idx]:
+ inp.avoid_NHCWB16 = True
+ op.type = Op.PackReshaped
+
+ inputs, axis = op.get_concat_inputs_axis()
+
+ for idx, inp in enumerate(inputs):
+ if op.type != Op.PackReshaped:
+ op.ifm_shapes[idx] = Shape4D(inp.shape)
if axis >= 0:
axis_4D = axis + (4 - len(inp.shape))
else:
axis_4D = axis
- new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
- new_op.inputs = [inp]
- new_op.outputs = [tens]
- new_op.attrs["concat_axis"] = axis_4D
- new_op.attrs["concat_start"] = offset
- offset += inp.shape[axis]
- new_op.attrs["concat_end"] = offset
- new_op.run_on_npu = True
- tens.ops.append(new_op)
- DebugDatabase.add_optimised(concat_op, new_op)
- new_op.set_ifm_ofm_shapes()
- assert tens.shape[axis] == offset
-
- # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
- # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
- # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
- # and those addresses are always 16 byte aligned due to the NHCWB16 format.
- if axis == -1 or axis == (len(tens.shape) - 1):
- for op in tens.ops:
- if op.attrs["concat_start"] % 16 != 0:
- tens.avoid_NHCWB16 = True
- break
+ new_op = Operation(Op.ConcatSliceWrite, op.name + str(idx))
+ new_op.inputs = [inp]
+ new_op.outputs = [ofm]
+ new_op.attrs["concat_axis"] = axis_4D
+ new_op.attrs["concat_start"] = offset
+ offset += op.ifm_shapes[idx].get_dim(axis_4D)
- return tens
+ new_op.attrs["concat_end"] = offset
+ new_op.run_on_npu = True
+ ofm.ops.append(new_op)
+ DebugDatabase.add_optimised(op, new_op)
+ new_op.ifm_shapes.append(op.ifm_shapes[idx].clone())
+ new_op.ofm_shapes.append(op.ofm_shapes[0].clone())
+ assert ofm.shape[axis] == offset
+
+ # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
+ # multiple of 16. This is because only then is the address offset for the ofm 16 byte aligned for all
+ # operations. For other values of axis the address offsets will be 16 byte aligned, as they are all based on
+ # c = 0 and those addresses are always 16 byte aligned due to the NHCWB16 format.
+ if axis == -1 or axis == (len(ofm.shape) - 1):
+ for op in ofm.ops:
+ if op.attrs["concat_start"] % 16 != 0:
+ ofm.avoid_NHCWB16 = True
+ break
+ return op
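
To make the alignment rule concrete, here is a minimal sketch (illustrative only, not part of the patch; the helper name is invented) of the check that decides when the ofm must fall back to a linear format:

    def concat_needs_linear_format(channel_sizes):
        # Concatenating along C: each slice starts where the previous one
        # ended, and NHCWB16 needs every start offset to be a multiple of 16.
        offset = 0
        for size in channel_sizes:
            if offset % 16 != 0:
                return True
            offset += size
        return False

    assert not concat_needs_linear_format([16, 32])  # starts at 0 and 16: ok
    assert concat_needs_linear_format([8, 24])       # second slice starts at 8
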
-def rewrite_split(tens, arch, nng):
+def rewrite_split_ops(tens, arch, nng):
if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
split_op = tens.ops[0]
@@ -118,20 +134,27 @@ def rewrite_split(tens, arch, nng):
tens.ops = []
new_op = Operation(Op.SplitSliceRead, split_op.name)
new_op.inputs = [inp]
+ ofm_shape_idx = 0
# For Split the offset cannot be extracted from the tensor so it has to
# be calculated from the index of the output tensor
if axis is not None:
# Get the start and end of the split
offset_start = [0] * 4
+ axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
for idx, out in enumerate(outputs):
- split_op.ofm_shapes[idx] = Shape4D(out.shape)
+ if axis_4D_list is not None:
+ axis_4D = axis_4D_list[idx]
+ else:
+ split_op.ofm_shapes[idx] = Shape4D(out.shape)
+ if axis >= 0:
+ axis_4D = axis + (4 - len(out.shape))
+ else:
+ axis_4D = axis
+
if out == tens:
+ ofm_shape_idx = idx
break
- if axis >= 0:
- axis_4D = axis + (4 - len(out.shape))
- else:
- axis_4D = axis
offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
@@ -145,7 +168,7 @@ def rewrite_split(tens, arch, nng):
new_op.run_on_npu = True
new_op.set_output_tensor(tens)
new_op.ifm_shapes.append(Shape4D(inp.shape))
- new_op.ofm_shapes.append(Shape4D(full_shape(4, tens.shape, 1)))
+ new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx].clone())
DebugDatabase.add_optimised(split_op, new_op)
return tens
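
The axis arithmetic that recurs in the concat, split, strided-slice and unpack rewrites is plain left-padding of a shape to rank 4; a minimal sketch (helper name invented for illustration):

    def to_axis_4d(axis, rank):
        # Left-padding a shape to rank 4 shifts non-negative axes by
        # (4 - rank); negative axes count from the right and are unchanged.
        return axis + (4 - rank) if axis >= 0 else axis

    assert to_axis_4d(1, 2) == 3    # axis 1 of [B, C] is axis 3 of [1, 1, B, C]
    assert to_axis_4d(-1, 3) == -1  # the last axis stays the last axis
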
@@ -158,9 +181,9 @@ def needed_total_padding(input_size, stride, filter_size):
return total_padding
-def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims, explicit_padding):
- ypad = needed_total_padding(int(input_dims[1]), int(stride[1]), int(kernel_size[0]))
- xpad = needed_total_padding(int(input_dims[2]), int(stride[2]), int(kernel_size[1]))
+def calc_padding_and_skirt(padding_type, kernel_size, stride, input_shape, explicit_padding):
+ ypad = needed_total_padding(int(input_shape.height), int(stride[1]), int(kernel_size[0]))
+ xpad = needed_total_padding(int(input_shape.width), int(stride[2]), int(kernel_size[1]))
if padding_type == Padding.SAME:
left_pad = (xpad + 0) // 2
right_pad = (xpad + 1) // 2
@@ -184,11 +207,11 @@ def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims, explic
return padding, skirt
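
needed_total_padding itself is not shown in this hunk; under the usual SAME-padding convention it amounts to the following sketch, which also explains the (xpad + 0) // 2 and (xpad + 1) // 2 split above:

    def needed_total_padding_sketch(input_size, stride, filter_size):
        # SAME output size, then the input extent the kernel would need;
        # the shortfall is the total padding.
        out_size = (input_size + stride - 1) // stride
        needed_input = (out_size - 1) * stride + filter_size
        return max(0, needed_input - input_size)

    # Width 224, stride 2, 3x3 kernel: 112 outputs, 1 padding column,
    # split as 0 on the left and 1 on the right.
    assert needed_total_padding_sketch(224, 2, 3) == 1
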
-def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims, upscaling_factor):
+def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
kernel_height, kernel_width = kernel_size[0], kernel_size[1]
if padding_type == Padding.SAME:
- ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height))
- xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width))
+ ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
+ xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
left_pad = max(kernel_width - 1 - right_pad, 0)
@@ -225,7 +248,7 @@ def convert_resizebilinear_1x1_to_add(op):
op.name = op.name + "_add"
op.attrs["resizebilinear"] = True
# Create an input tensor filled with zeros
- shape = op.outputs[0].shape
+ shape = op.ofm_shapes[0].as_list()
tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
tens.values = np.zeros(shape)
tens.quant_values = np.zeros(shape, np.uint8)
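
The rewrite is sound because bilinearly resizing a 1x1 input just broadcasts its single value, which an elementwise Add against zeros of the output shape reproduces exactly; a quick numpy check (shapes chosen arbitrarily):

    import numpy as np

    ifm = np.full((1, 1, 1, 8), 7.0)     # 1x1 spatial input
    zeros = np.zeros((1, 4, 4, 8))       # the zero tensor created above
    assert ((ifm + zeros) == 7.0).all()  # broadcast == bilinear upscale
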
@@ -258,8 +281,8 @@ def convert_resizebilinear_to_2x2_pool(op):
op.attrs["padding"] = Padding.SAME
op.inputs[0].resampling_mode = resampling_mode.NEAREST
- upscaled_shape = np.array(op.inputs[0].shape[1:3])
- out_shape = np.array(op.outputs[0].shape[1:3])
+ upscaled_shape = op.ifm_shapes[0].get_hw_as_list()
+ out_shape = op.ofm_shapes[0].get_hw_as_list()
if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
return op
@@ -276,8 +299,8 @@ def convert_resizebilinear_to_2x2_pool(op):
scaled_op.outputs = outputs
scaled_op.outputs[0].ops = [scaled_op]
else:
- shape = outputs[0].shape.copy()
- shape[1:3] = upscaled_shape[0:2]
+ shape = op.ofm_shapes[0].as_list()
+ shape[1:3] = upscaled_shape
out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
out_tens.quantization = op.outputs[0].quantization.clone()
out_tens.quantization.quant_min = np.iinfo(np.int16).min
@@ -300,11 +323,11 @@ def convert_resizebilinear_to_2x2_pool(op):
def fixup_resizebilinear(op, arch, nng):
if op.type == Op.ResizeBilinear and op.run_on_npu:
- if op.inputs[0].shape == op.outputs[0].shape:
+ if op.ifm_shapes[0] == op.ofm_shapes[0]:
# Bypass nop resizebilinear
op.inputs = op.inputs[:1]
op.type = Op.Identity
- elif op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
+ elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
convert_resizebilinear_1x1_to_add(op)
else:
convert_resizebilinear_to_2x2_pool(op)
@@ -321,109 +344,26 @@ def convert_nop_split_to_identity(op, arch, nng):
return op
-def fixup_fully_connected_input(op, arch, nng):
- if op.type == Op.FullyConnected:
- inp = op.inputs[0]
- weights = op.inputs[1]
-
- n_in_elems = weights.shape[-2]
- elms = inp.elements()
- batch_size = elms // n_in_elems
- assert batch_size * n_in_elems == elms
-
- desired_shape = [batch_size, n_in_elems]
- if inp.shape != desired_shape:
- # mismatch, insert a reshape to fix this.
- op.set_input_tensor(create_reshape_tensor(inp, desired_shape), 0)
-
- return op
-
-
def convert_batched_fc_shape(op, arch, nng):
if op.type == Op.FullyConnected:
- ifm = op.inputs[0]
- ofm = op.outputs[0]
- # Check if the FC is 2D and first dimension indicates batching
- # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete
- if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1:
- n = ifm.shape[0]
+ # Check if the first dimension indicates batching
+ if op.ifm_shapes[0].batch > 1:
batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+ n = op.ifm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
+ op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
- prev_op = ifm.ops[0]
- desired_shape = [1, h, w, ifm.shape[-1]]
- op.ifm_shapes[0] = Shape4D(desired_shape)
-
- if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
- # There is a preceding Reshape
- # Compare input of prev_op and input of op, to see if prev_op can be removed
- ifm_prev_op = prev_op.inputs[0]
- if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm):
- # prev_op can be removed
- op.set_input_tensor(ifm_prev_op, 0)
- else:
- op.inputs[0].set_all_shapes(desired_shape)
- prev_op.set_input_tensor(
- create_const_tensor(prev_op.inputs[1].name, [1], DataType.int32, desired_shape), 1
- )
- prev_op.attrs["new_shape"] = desired_shape
- else:
- # Add reshape op to the input if there is no preceding reshape
- ifm.consumer_list.remove(op)
- op.set_input_tensor(create_reshape_tensor(ifm, desired_shape), 0)
+ op.ifm.avoid_NHCWB16 = True
# Reshape Weights to be 4D. IO becomes HWIO
weight_tensor = op.inputs[1]
weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
- desired_shape = [1, h, w, ofm.shape[-1]]
- op.ofm_shapes[0] = Shape4D(desired_shape)
-
- if (
- len(ofm.consumer_list) == 1
- and ofm.consumer_list[0] is not None
- and ofm.consumer_list[0].type == Op.Reshape
- ):
- # There is a subsequent Reshape
- # Compare desired shape and output of consumer op, to see if consumer op can be removed
- ofm_cons_op = ofm.consumer_list[0].outputs[0]
- if desired_shape == ofm_cons_op.shape and check_quantized_tens_scaling_equal(ofm, ofm_cons_op):
- op.outputs[0] = ofm_cons_op
- op.outputs[0].ops = [op]
- else:
- op.outputs[0].set_all_shapes(desired_shape)
- else:
- # Add reshape op to the output
- op.set_output_tensor(create_reshape_tensor(ofm, desired_shape, False))
- return op
-
-
-def fixup_pack_input(op, arch, nng):
- if op.type == Op.Pack:
- # Pack is also referred to as Stack
- # Requires the rewrite_concat function to be called on the op afterwards
- axis = int(op.attrs["axis"])
- desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
-
- # Construct 1 shape tensor to be used by all inserted reshape ops
- new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, desired_shape)
-
- for idx, inp in enumerate(op.inputs):
- reshape_out = inp.clone("_reshaped")
- reshape_out.set_all_shapes(desired_shape)
-
- reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
- reshape_op.attrs["new_shape"] = desired_shape
- reshape_op.inputs = [inp, new_shape_tens]
- reshape_op.set_output_tensor(reshape_out)
- reshape_op.set_ifm_ofm_shapes()
- DebugDatabase.add_optimised(op, reshape_op)
-
- op.inputs[idx] = reshape_out
-
- op.type = Op.PackReshaped
-
+ n = op.ofm_shapes[0].batch
+ h, w = batching_split.get(n, (1, n))
+ op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
+ op.ofm.avoid_NHCWB16 = True
return op
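
The batching_split table maps the FC batch count onto an H x W grid so the tensor stays 4D, with unlisted batch sizes falling back to a single row; for example:

    batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}

    def fc_batch_to_hw(n):
        return batching_split.get(n, (1, n))

    assert fc_batch_to_hw(8) == (2, 4)  # [8, C] is processed as [1, 2, 4, C]
    assert fc_batch_to_hw(6) == (1, 6)  # unlisted batch: [1, 1, 6, C]
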
@@ -441,12 +381,19 @@ def unfuse_activation_function(op, arch, nng):
return op
-def fixup_stridedslice_output(tens, arch, nng):
- op = tens.ops[0]
- if op.run_on_npu and op.type == Op.StridedSlice:
- reshape_input_shape = tens.shape
- new_axis_mask = op.attrs["new_axis_mask"]
- shrink_axis_mask = op.attrs["shrink_axis_mask"]
+def rewrite_stridedslice_output(op, arch, nng):
+ if not op.run_on_npu or op.type != Op.StridedSlice:
+ return op
+
+ new_axis_mask = op.attrs["new_axis_mask"]
+ shrink_axis_mask = op.attrs["shrink_axis_mask"]
+
+ if shrink_axis_mask == 0 and new_axis_mask == 0:
+ return op
+
+ axis_4D = [0] * len(op.outputs)
+ for idx, out_tens in enumerate(op.outputs):
+ output_shape = list(out_tens.shape)
if shrink_axis_mask != 0:
n = 0
@@ -456,10 +403,16 @@ def fixup_stridedslice_output(tens, arch, nng):
n += 1
shrink_axis_mask &= shrink_axis_mask - 1
axis = int(math.log2(prev_mask - shrink_axis_mask))
- reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:]
+ output_shape = output_shape[:axis] + [1] + output_shape[axis:]
- assert len(tens.shape) == (len(op.inputs[0].shape) - n)
+ assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
op.attrs["shrink_axis_mask"] = 0
+ if axis >= 0:
+ axis_4D[idx] = axis + (4 - len(output_shape))
+ else:
+ axis_4D[idx] = axis
+ op.ofm_shapes[idx] = Shape4D(output_shape)
+
elif new_axis_mask != 0:
n = 0
axis = 0
@@ -468,77 +421,62 @@ def fixup_stridedslice_output(tens, arch, nng):
n += 1
new_axis_mask &= new_axis_mask - 1
axis = int(math.log2(prev_mask - new_axis_mask))
- reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
+ output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
new_axis_mask >>= 1
- assert len(tens.shape) == (len(op.inputs[0].shape) + n)
+ assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
op.attrs["new_axis_mask"] = 0
- else:
- # Equal Rank StridedSlice, no need to insert reshape
- return tens
-
- # Construct 1 shape tensor to be used by all inserted reshape ops
- new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
-
- for idx, out_tens in enumerate(op.outputs):
- op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape)
- reshape_in = out_tens.clone("_reshaped")
- reshape_in.set_all_shapes(reshape_input_shape)
- reshape_in.ops = [op]
-
- reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
- reshape_op.attrs["new_shape"] = reshape_input_shape
- reshape_op.inputs = [reshape_in, new_shape_tens]
- reshape_op.set_output_tensor(out_tens)
- reshape_op.set_ifm_ofm_shapes()
+ if axis >= 0:
+ axis_4D[idx] = axis + (4 - len(output_shape))
+ else:
+ axis_4D[idx] = axis
+ op.ofm_shapes[idx] = Shape4D(output_shape)
- op.outputs[idx] = reshape_in
+ if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
+ out_tens.avoid_NHCWB16 = True
- return tens
+ op.attrs["split_axis_4D"] = axis_4D
+ return op
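
Both mask loops rely on the classic mask & (mask - 1) trick: each iteration clears the lowest set bit, and log2 of the cleared bit gives the axis it flags. A standalone sketch:

    import math

    def mask_to_axes(mask):
        axes = []
        while mask != 0:
            prev_mask = mask
            mask &= mask - 1  # clear the lowest set bit
            axes.append(int(math.log2(prev_mask - mask)))  # its position
        return axes

    assert mask_to_axes(0b0101) == [0, 2]  # bits 0 and 2 flag axes 0 and 2
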
-def fixup_unpack_output(tens, arch, nng):
- op = tens.ops[0]
+def rewrite_unpack_output(op, arch, nng):
+ tens = op.outputs[0]
if op.run_on_npu and op.type == Op.Unpack:
# Unpack is also referred to as Unstack
- # Requires the rewrite_split function to be called on the op afterwards
axis = int(op.attrs["axis"])
op.type = Op.UnpackReshaped
- reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
+ desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
- # Construct 1 shape tensor to be used by all inserted reshape ops
- new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
+ if axis >= 0:
+ axis_4D = axis + (4 - len(desired_output_shape))
+ else:
+ axis_4D = axis
+ axis_4D_list = [0] * len(op.outputs)
for idx, out_tens in enumerate(op.outputs):
- reshape_in = out_tens.clone("_reshaped")
- reshape_in.set_all_shapes(reshape_input_shape)
- reshape_in.ops = [op]
-
- reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
- reshape_op.attrs["new_shape"] = reshape_input_shape
- reshape_op.inputs = [reshape_in, new_shape_tens]
- reshape_op.set_output_tensor(out_tens)
- reshape_op.set_ifm_ofm_shapes()
- DebugDatabase.add_optimised(op, reshape_op)
-
- op.outputs[idx] = reshape_in
- return tens
+ op.ofm_shapes[idx] = Shape4D(desired_output_shape)
+ axis_4D_list[idx] = axis_4D
+ if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
+ out_tens.avoid_NHCWB16 = True
+
+ op.attrs["split_axis_4D"] = axis_4D_list
+ return op
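
Worked through for an Unpack of a [3, 8, 8] ifm along axis 0 (values taken from the code above): each of the three outputs has shape [8, 8], and

    tens_shape = [8, 8]
    axis = 0
    desired = tens_shape[:axis] + [1] + tens_shape[axis:]  # [1, 8, 8]
    axis_4D = axis + (4 - len(desired))                    # 1
    assert (desired, axis_4D) == ([1, 8, 8], 1)

so every output is later read by rewrite_split_ops as a [1, 1, 8, 8] slice at offsets 0, 1, 2 along axis_4D, carried via the split_axis_4D attribute.
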
def add_padding_fields(op, arch, nng):
if op.run_on_npu:
if "padding" in op.attrs:
+ input_shape = op.ifm_shapes[0]
+ output_shape = op.ofm_shapes[0]
if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
kernel_size = op.inputs[1].shape[:2]
- input_shape = op.inputs[0].shape
elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
kernel_size = op.attrs["ksize"][1:3]
- input_shape = op.inputs[0].shape
else:
raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
if op.type == Op.Conv2DBackpropInputSwitchedBias:
- upscaling_factor = op.outputs[0].shape[1] // input_shape[1]
+ upscaling_factor = output_shape.height // input_shape.height
padding, skirt = calc_upscaled_padding_and_skirt(
op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
)
@@ -582,10 +520,10 @@ def convert_depthwise_to_conv(op, arch, nng):
# switch of the operator type (and weight order)
if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
- ifm_tensor = op.inputs[0]
+ ifm_shape = op.ifm_shapes[0]
weight_tensor = op.inputs[1]
- ofm_tensor = op.outputs[0]
- if (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"]):
+ ofm_shape = op.ofm_shapes[0]
+ if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
# Change op type to Conv2d
op.type = Op.Conv2DBias
del op.attrs["channel_multiplier"]
@@ -596,7 +534,7 @@ def convert_depthwise_to_conv(op, arch, nng):
else:
raise UnsupportedFeatureError(
f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
- f" ifm channels = {ifm_tensor.shape[3]}, ofm channels = {ofm_tensor.shape[3]}",
+ f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
)
DebugDatabase.add_optimised(op, op)
return op
@@ -620,17 +558,15 @@ def optimise_strided_conv(op, arch, nng):
op.type == Op.Conv2DBias
and op.op_index == 0
and stride_x == 2
- and len(ifm_tensor.shape) == 4
- and ifm_tensor.shape[3] <= 4
- and ifm_tensor.shape[2] % 2 == 0
+ and op.ifm_shapes[0].depth <= 4
+ and op.ifm_shapes[0].width % 2 == 0
and weight_tensor is not None
and weight_tensor.shape[1] >= 2
):
+ ifm_shape = op.ifm_shapes[0]
# IFM
- ifm_reshaped = create_reshape_tensor(
- ifm_tensor, [ifm_tensor.shape[0], ifm_tensor.shape[1], ifm_tensor.shape[2] // 2, ifm_tensor.shape[3] * 2]
- )
- op.set_input_tensor(ifm_reshaped, 0)
+ op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
+ op.ifm.avoid_NHCWB16 = True
# Weights
weight_shape = weight_tensor.shape
@@ -657,8 +593,6 @@ def optimise_strided_conv(op, arch, nng):
stride_x = 1
op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
- op.set_ifm_ofm_shapes()
-
return op
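
The width-halving is legal because NHWC keeps W and C innermost in memory: folding each pair of adjacent pixels into the channel dimension reinterprets the same buffer, so the stride-2 convolution can run as stride-1 over the packed ifm (the weights are rearranged to match). A small numpy check:

    import numpy as np

    ifm = np.arange(1 * 4 * 8 * 2).reshape(1, 4, 8, 2)  # N, H, W, C
    packed = ifm.reshape(1, 4, 4, 4)                    # W // 2, C * 2
    # One packed "pixel" holds two adjacent input pixels, unchanged in memory.
    assert packed[0, 0, 0].tolist() == ifm[0, 0, 0:2].reshape(-1).tolist()
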
@@ -683,27 +617,6 @@ def convert_conv_to_fc(op, arch, nng):
weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
- # The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it
- # back to 4D afterwards as the next layer is expecting that shape
- orig_ofm_tensor = op.outputs[0]
- # Reshape this ops output to be 2D: {(N*H*W), C} (We know N H and W are all 1 so this becomes {1, C})
- fc_ofm_tensor = orig_ofm_tensor.clone("_fc")
- fc_ofm_tensor.set_all_shapes([1, fc_ofm_tensor.shape[-1]])
- fc_ofm_tensor.ops = [op]
- # Add a reshape after the new OFM to convert it back to the original 4D shape
- reshape_name = op.name + "_reshape"
- new_shape_tens = create_const_tensor(reshape_name + "_shape", [1], DataType.int32, orig_ofm_tensor.shape)
- reshape_op = Operation(Op.Reshape, reshape_name)
- reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
- reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
- reshape_op.set_output_tensor(orig_ofm_tensor)
- reshape_op.set_ifm_ofm_shapes()
-
- # Replace this ops OFM to point to the 2D tensor
- op.outputs[0] = fc_ofm_tensor
- op.set_ifm_ofm_shapes()
- # Record optimisation in debug database
- DebugDatabase.add_optimised(op, reshape_op)
DebugDatabase.add_optimised(op, op)
return op
@@ -722,14 +635,6 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
# Tidy up and assign the ifm and ofm to the new op
ifm.consumer_list.remove(op)
- # if not 4d, reshape ifm/ofm
- if len(ifm.shape) < 4:
- ifm_shaped = create_reshape_tensor(ifm, full_shape(4, ifm.shape, 1))
- ifm = ifm_shaped
- if len(ofm.shape) < 4:
- ofm_shaped = create_reshape_tensor(ofm, full_shape(4, ofm.shape, 1), False)
- ofm = ofm_shaped
-
relu_fused_op.add_input_tensor(ifm)
relu_fused_op.set_output_tensor(ofm)
relu_fused_op.set_ifm_ofm_shapes()
@@ -737,6 +642,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
return op
+# TODO: remove this pass once memory-only ops can all be removed
# Reorder activation op if it's after the memory only operations
def fixup_act_reorder(op, arch, nng):
if op.type.is_relu_op() or op.type in (Op.Sigmoid, Op.Tanh):
@@ -752,8 +658,8 @@ def fixup_act_reorder(op, arch, nng):
act_op_out = act_op.inputs[0].clone("_acted")
act_op_out.quantization = op.outputs[0].quantization.clone()
act_op.set_output_tensor(act_op_out)
- act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape)
- act_op.ofm_shapes[0] = Shape4D(act_op_out.shape)
+ act_op.ofm_shapes[0] = act_op.ifm_shapes[0].clone()
+ act_op.ifm_shapes[0] = prep_op.ifm_shapes[0].clone()
# Update the consumer list
act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -1078,39 +984,94 @@ def convert_tanh_sigmoid_to_lut(op, arch, nng):
return op
-def remove_unwanted_reshapes(op, arch, nng):
- # Try to remove reshapes enclosing ElementWise operator with only one non-constant input
- if not op.run_on_npu or not op.type.is_elementwise_op():
- return op
+def remove_reshapes(op, arch):
+ if op.run_on_npu and op.type == Op.Reshape:
+ ofm = op.ofm
+ ifm = op.ifm
- # Check if the ElementWise operator only have one non-constant input
- non_const_tens = [x for x in op.inputs if x.ops[0].type != Op.Const]
- if len(non_const_tens) != 1:
- return op
- ifm = non_const_tens[0]
+ # Check if quantization is the same in the input and output for the reshape ops
+ if not check_quantized_tens_scaling_equal(ifm, ofm):
+ # TODO Both tensors are needed, since quantisation properties are currently linked to Tensors.
+ # In order to remove this reshape, either the quantization properties need to be moved to the Operator,
+ # or the reshape needs to be replaced with a NOP.
+ return
+
+ # Check if ifm is a sg input
+ if ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
+ # put the reshape on CPU
+ op.run_on_npu = False
+ return
+
+ # Check if Reshape ifm/ofm are network ifm/ofm
+ ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
+ ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+
+ if ifm_is_sg_ofm and ofm_is_sg_ofm:
+ # Both ifm and ofm are sg outputs, add a reshape to the ifm and put it on the CPU
+ ifm_cons_list_copy = ifm.consumer_list.copy()
+ ifm_ops_copy = ifm.ops.copy()
+ for ifm_cons in ifm_cons_list_copy:
+ if ifm_cons is None:
+ # Create a reshape op with ifm as output
+ name = ifm.name + "_cpu_reshape"
+ reshape_ifm = ifm.clone()
+ reshape_op = Operation(Op.Reshape, name)
+ reshape_op.attrs["new_shape"] = ifm.shape
+ reshape_op.add_input_tensor(reshape_ifm)
+ reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, ifm.shape))
+ reshape_op.set_output_tensor(ifm)
+ reshape_op.set_ifm_ofm_shapes()
+ reshape_op.run_on_npu = False
+ reshape_op.ofm.ops = [reshape_op]
+ reshape_op.ofm.consumer_list = [None]
+
+ # Set reshape_ifm producers
+ for prev_op in ifm_ops_copy:
+ prev_op.outputs = [reshape_ifm]
+ reshape_ifm.ops.append(prev_op)
+
+ # Set reshape_ifm consumers
+ for ifm_cons in ifm_cons_list_copy:
+ if ifm_cons is not None:
+ for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+ if cons_ifm == ifm:
+ ifm_cons.set_input_tensor(reshape_ifm, ifm_idx)
+
+ ifm = reshape_ifm
+ break
+ ifm_is_sg_ofm = False
+
+ if ofm_is_sg_ofm:
+ # Bypass the Reshape by replacing ifm with ofm
+ ofm.ops = []
+ for prev_op in ifm.ops:
+ prev_op.outputs = [ofm]
+ ofm.ops.append(prev_op)
+
+ # All ifm consumers need to use ofm as input
+ for ifm_cons in ifm.consumer_list:
+ for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+ if cons_ifm == ifm:
+ ifm_cons.set_input_tensor(ofm, ifm_idx)
+ if op.ifm_shapes[0] != op.ofm_shapes[0]:
+ ofm.avoid_NHCWB16 = True
+ else:
+ # Bypass the Reshape by replacing ofm with ifm
+ for cons in ofm.consumer_list:
+ for ifm_idx, cons_ifm in enumerate(cons.inputs):
+ if cons_ifm == ofm:
+ cons.set_input_tensor(ifm, ifm_idx)
+ if op.ifm_shapes[0] != op.ofm_shapes[0]:
+ ifm.avoid_NHCWB16 = True
- # Check if operation is enclosed by Reshapes that can be removed
- ofm = op.outputs[0]
- prev_op = ifm.ops[0]
- if (
- len(ifm.consumer_list) == 1
- and prev_op.type == Op.Reshape
- and len(ofm.consumer_list) == 1
- and ofm.consumer_list[0].type == Op.Reshape
- ):
- # Operation is enclosed by reshapes, check if they can be removed
- prev_op_ifm, prev_op_ofm = prev_op.get_ifm_ofm()
- cons_op = ofm.consumer_list[0]
- cons_op_ifm = ofm
- cons_op_ofm = cons_op.outputs[0]
- if len(prev_op_ifm.shape) == len(cons_op_ofm.shape):
- # Check if quantization is the same in the input and output for the reshape ops
- if check_quantized_tens_scaling_equal(prev_op_ifm, prev_op_ofm) and check_quantized_tens_scaling_equal(
- cons_op_ifm, cons_op_ofm
- ):
- op.set_input_tensor(prev_op_ifm, 0)
- op.set_output_tensor(cons_op_ofm)
- return op
+
+def check_reshapes(op, arch):
+ if op.run_on_npu and op.type == Op.Reshape:
+ ofm = op.ofm
+
+ if check_quantized_tens_scaling_equal(op.ifm, ofm):
+ # Reshape should have been removed
+ raise VelaError(f"Reshape op {op} was expected to have been removed, but it still remains")
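
The removal logic above reduces to a small decision table; a sketch (function name invented, behaviour summarised from remove_reshapes):

    def reshape_bypass_strategy(ifm_is_sg_input, ifm_is_sg_ofm, ofm_is_sg_ofm):
        if ifm_is_sg_input:
            return "run reshape on CPU"        # producer is Placeholder/SubgraphInput/Const
        if ifm_is_sg_ofm and ofm_is_sg_ofm:
            return "insert CPU reshape on ifm, then replace ifm with ofm"
        if ofm_is_sg_ofm:
            return "replace ifm with ofm"      # producers write straight to ofm
        return "replace ofm with ifm"          # consumers read straight from ifm

    assert reshape_bypass_strategy(False, False, False) == "replace ofm with ifm"
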
def fuse_activation_function_with_prev(op, arch, nng):
@@ -1174,13 +1135,19 @@ def optimise_pad(op, arch, nng):
def add_attrs_to_resizebilinear(op, arch, nng):
if op.type == Op.ResizeBilinear and op.run_on_npu:
input_tensor = op.inputs[0]
- upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
- out_shape = op.outputs[0].shape[1:3]
- if not op.attrs["align_corners"] and out_shape == upscaled_shape:
+ input_shape = op.ifm_shapes[0]
+ upscaled_height = input_shape.height * 2
+ upscaled_width = input_shape.width * 2
+ out_shape = op.ofm_shapes[0]
+ if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
# this means the output is supposed to be a x2 upscale,
# so we need to do SAME padding
op.attrs["padding"] = Padding.SAME
- elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
+ elif (
+ op.attrs["align_corners"]
+ and out_shape.height == (upscaled_height - 1)
+ and out_shape.width == (upscaled_width - 1)
+ ):
# here we can just run the avg pool without padding and
# produce a (M * 2 - 1, N * 2 - 1) sized output
op.attrs["padding"] = Padding.VALID
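
In other words, the two supported output geometries for an x2 bilinear upscale are (a summary of the branches above, not new behaviour):

    def expected_out_hw(h, w, align_corners):
        # align_corners=False: out == (2h, 2w)     -> SAME padding
        # align_corners=True:  out == (2h-1, 2w-1) -> VALID padding
        return (2 * h - 1, 2 * w - 1) if align_corners else (2 * h, 2 * w)

    assert expected_out_hw(8, 8, False) == (16, 16)
    assert expected_out_hw(8, 8, True) == (15, 15)
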
@@ -1229,26 +1196,52 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
)
+ # Handle Concat Ops
+ for idx, sg in enumerate(nng.subgraphs):
+ # rewrite graph pass
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [], [rewrite_concat_ops], rewrite_unsupported=False,
+ )
+
+ # Handle Split Ops
+ for idx, sg in enumerate(nng.subgraphs):
+ # rewrite graph pass
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng,
+ sg,
+ arch,
+ [],
+ [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
+ rewrite_unsupported=False,
+ )
+
+ for idx, sg in enumerate(nng.subgraphs):
+ # rewrite graph pass
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
+ )
+
+ # Removal of reshapes
+ for sg in nng.subgraphs:
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+ sg.refresh_after_modification()
+
op_rewrite_list = [
set_tensor_equivalence,
convert_depthwise_to_conv,
convert_conv_to_fc,
convert_softmax,
optimise_strided_conv,
- fixup_fully_connected_input,
convert_batched_fc_shape,
- fixup_pack_input,
unfuse_activation_function,
fixup_conv2d_backprop,
fixup_relus_with_differing_ifm_ofm_scaling,
fixup_act_reorder,
- fixup_elementwise_with_scalars,
+ fixup_elementwise_with_scalars, # TODO Move to early stage?
reorder_depthwise_weights,
fixup_resizebilinear,
fixup_bias_tensors,
- convert_nop_split_to_identity,
convert_mul_max_to_abs_or_lrelu,
- remove_unwanted_reshapes,
convert_lrelu,
convert_tanh_sigmoid_to_lut,
]
@@ -1269,24 +1262,9 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
[fuse_activation_function_with_prev, optimise_pad, add_padding_fields],
)
- # Post-optimisation operator debug tracing
+ # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
for sg in nng.subgraphs:
- rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [_record_optimised])
-
- if verbose_graph:
- nng.print_graph()
- return nng
-
-
-def optimise_graph_b(nng, arch, verbose_graph=False):
- if verbose_graph:
- nng.print_graph()
-
- for idx, sg in enumerate(nng.subgraphs):
- # combined rewrite graph pass
- nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [],
- )
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])
if verbose_graph:
nng.print_graph()