From 3a26920b7cd302364d68830eb6e374311ce17f22 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Thu, 21 Jan 2021 08:28:55 +0100 Subject: MLBEDSW-3772 Reshape removal -Removed reshapes in the original graph -Removed the addition of reshapes to the optimized graph -Reshapes with different ifm/ofm quantisation will remain Signed-off-by: Patrik Gustavsson Change-Id: I94862be53dac0d7434815e2aee5ca678228495f8 --- ethosu/vela/compiler_driver.py | 3 - ethosu/vela/debug_database.py | 12 +- ethosu/vela/graph_optimiser.py | 580 ++++++++++----------- ethosu/vela/high_level_command_stream_generator.py | 6 +- ethosu/vela/high_level_command_to_npu_op.py | 19 +- ethosu/vela/npu_performance.py | 56 +- ethosu/vela/operation.py | 2 - ethosu/vela/operation_util.py | 85 ++- ethosu/vela/pass_packing.py | 6 +- ethosu/vela/shape4d.py | 3 + ethosu/vela/softmax.py | 61 ++- ethosu/vela/tensor.py | 105 ++-- ethosu/vela/test/test_graph_optimiser.py | 131 ++++- ethosu/vela/test/testutil.py | 5 +- 14 files changed, 634 insertions(+), 440 deletions(-) diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py index 78d7f12a..3d4f7584 100644 --- a/ethosu/vela/compiler_driver.py +++ b/ethosu/vela/compiler_driver.py @@ -146,9 +146,6 @@ def compiler_driver(nng, arch, options, scheduler_options): if options.verbose_quantization: nng.print_graph_with_tensor_quantization() - nng = graph_optimiser.optimise_graph_b(nng, arch, options.verbose_graph) - assert verify_graph_health(nng) - nng = mark_tensors.mark_tensor_purpose(nng, arch, options.verbose_tensor_purpose) assert verify_graph_health(nng) nng = insert_dma.insert_dma_commands(nng, arch, options.verbose_graph) diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py index 006348c0..adf03627 100644 --- a/ethosu/vela/debug_database.py +++ b/ethosu/vela/debug_database.py @@ -25,6 +25,7 @@ import lxml.etree as xml from . import numeric_util from .operation import Operation +from .shape4d import Shape4D class DebugDatabase: @@ -77,7 +78,10 @@ class DebugDatabase: src_uid = cls._sourceUID[parent] uid = len(cls._optimisedUID) cls._optimisedUID[op] = (uid, src_uid) - ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1) + if len(op.ofm_shapes) == 0: + ofm_shape = Shape4D(op.outputs[0].shape) + else: + ofm_shape = op.ofm_shapes[0] cls._optimisedTable.append( [ uid, @@ -85,9 +89,9 @@ class DebugDatabase: str(op.type), op.kernel.width, op.kernel.height, - ofm_shape[-2], - ofm_shape[-3], - ofm_shape[-1], + ofm_shape.width, + ofm_shape.height, + ofm_shape.depth, ] ) diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 5f111786..bb5a9e03 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -28,6 +28,7 @@ from . 
import scaling from .data_type import DataType from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError +from .errors import VelaError from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .numeric_util import clamp_sigmoid from .numeric_util import full_shape @@ -42,7 +43,6 @@ from .shape4d import Shape4D from .softmax import SoftMax from .tensor import check_quantized_tens_scaling_equal from .tensor import create_const_tensor -from .tensor import create_reshape_tensor from .tensor import QuantizationParameters from .tensor import Tensor from .tflite_mapping import optype_to_builtintype @@ -59,52 +59,68 @@ def remove_passthrough_tensor(tens, arch, nng): return tens -def rewrite_concat(tens, arch, nng): - if len(tens.ops) == 1 and tens.ops[0].type.is_concat_op(): - concat_op = tens.ops[0] - if tens != concat_op.outputs[0]: - return tens # don't attempt to rewrite the min/max outputs of QuantizedConcat +def rewrite_concat_ops(op, arch, nng): + if not op.run_on_npu or not op.type.is_concat_op(): + return op - # Not supported so leave it and run on CPU - if not concat_op.run_on_npu: - return tens + axis_4D = 0 + ofm = op.ofm + ofm.ops = [] + offset = 0 - inputs, axis = concat_op.get_concat_inputs_axis() + if op.type == Op.Pack: + # Pack is also referred to as Stack + axis = int(op.attrs["axis"]) + desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:] - tens.ops = [] - offset = 0 - for idx, inp in enumerate(inputs): + if axis >= 0: + axis_4D = axis + (4 - len(desired_shape)) + else: + axis_4D = axis + + for idx, inp in enumerate(op.inputs): + op.ifm_shapes[idx] = Shape4D(desired_shape) + if Shape4D(inp.shape) != op.ifm_shapes[idx]: + inp.avoid_NHCWB16 = True + op.type = Op.PackReshaped + + inputs, axis = op.get_concat_inputs_axis() + + for idx, inp in enumerate(inputs): + if op.type != Op.PackReshaped: + op.ifm_shapes[idx] = Shape4D(inp.shape) if axis >= 0: axis_4D = axis + (4 - len(inp.shape)) else: axis_4D = axis - new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx)) - new_op.inputs = [inp] - new_op.outputs = [tens] - new_op.attrs["concat_axis"] = axis_4D - new_op.attrs["concat_start"] = offset - offset += inp.shape[axis] - new_op.attrs["concat_end"] = offset - new_op.run_on_npu = True - tens.ops.append(new_op) - DebugDatabase.add_optimised(concat_op, new_op) - new_op.set_ifm_ofm_shapes() - assert tens.shape[axis] == offset - - # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a - # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte - # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0 - # and those addresses are always 16 byte aligned due to the NHCWB16 format. 
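# --- annotation (not part of the patch): a minimal sketch of the 4D axis
# normalisation used by rewrite_concat_ops above. It assumes, as Shape4D does,
# that shorter shapes are right-aligned when padded out to four dimensions, so
# a non-negative axis shifts right by the amount of padding, while a negative
# axis (already counted from the back) maps unchanged.
def _to_axis_4D(axis, rank):
    return axis if axis < 0 else axis + (4 - rank)

assert _to_axis_4D(2, 3) == 3   # channel axis of an [H, W, C] shape -> depth
assert _to_axis_4D(-1, 3) == -1
# --- end annotation ---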
- if axis == -1 or axis == (len(tens.shape) - 1): - for op in tens.ops: - if op.attrs["concat_start"] % 16 != 0: - tens.avoid_NHCWB16 = True - break + new_op = Operation(Op.ConcatSliceWrite, op.name + str(idx)) + new_op.inputs = [inp] + new_op.outputs = [ofm] + new_op.attrs["concat_axis"] = axis_4D + new_op.attrs["concat_start"] = offset + offset += op.ifm_shapes[idx].get_dim(axis_4D) - return tens + new_op.attrs["concat_end"] = offset + new_op.run_on_npu = True + ofm.ops.append(new_op) + DebugDatabase.add_optimised(op, new_op) + new_op.ifm_shapes.append(op.ifm_shapes[idx].clone()) + new_op.ofm_shapes.append(op.ofm_shapes[0].clone()) + assert ofm.shape[axis] == offset + + # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a + # multiple of 16. This is because only then will the address offset for the ofm, for all operations, be 16 byte + # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0 + # and those addresses are always 16 byte aligned due to the NHCWB16 format. + if axis == -1 or axis == (len(ofm.shape) - 1): + for op in ofm.ops: + if op.attrs["concat_start"] % 16 != 0: + ofm.avoid_NHCWB16 = True + break + return op -def rewrite_split(tens, arch, nng): +def rewrite_split_ops(tens, arch, nng): if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack: split_op = tens.ops[0] @@ -118,20 +134,27 @@ def rewrite_split(tens, arch, nng): tens.ops = [] new_op = Operation(Op.SplitSliceRead, split_op.name) new_op.inputs = [inp] + ofm_shape_idx = 0 # For Split the offset cannot be extracted from the tensor so it has to # be calculated from the index of the output tensor if axis is not None: # Get the start and end of the split offset_start = [0] * 4 + axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice for idx, out in enumerate(outputs): - split_op.ofm_shapes[idx] = Shape4D(out.shape) + if axis_4D_list is not None: + axis_4D = axis_4D_list[idx] + else: + split_op.ofm_shapes[idx] = Shape4D(out.shape) + if axis >= 0: + axis_4D = axis + (4 - len(out.shape)) + else: + axis_4D = axis + if out == tens: + ofm_shape_idx = idx break - if axis >= 0: - axis_4D = axis + (4 - len(out.shape)) - else: - axis_4D = axis offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D) @@ -145,7 +168,7 @@ def rewrite_split(tens, arch, nng): new_op.run_on_npu = True new_op.set_output_tensor(tens) new_op.ifm_shapes.append(Shape4D(inp.shape)) - new_op.ofm_shapes.append(Shape4D(full_shape(4, tens.shape, 1))) + new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx].clone()) DebugDatabase.add_optimised(split_op, new_op) return tens @@ -158,9 +181,9 @@ def needed_total_padding(input_size, stride, filter_size): return total_padding -def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims, explicit_padding): - ypad = needed_total_padding(int(input_dims[1]), int(stride[1]), int(kernel_size[0])) - xpad = needed_total_padding(int(input_dims[2]), int(stride[2]), int(kernel_size[1])) +def calc_padding_and_skirt(padding_type, kernel_size, stride, input_shape, explicit_padding): + ypad = needed_total_padding(int(input_shape.height), int(stride[1]), int(kernel_size[0])) + xpad = needed_total_padding(int(input_shape.width), int(stride[2]), int(kernel_size[1])) if padding_type == Padding.SAME: left_pad = (xpad + 0) // 2 right_pad = (xpad + 1) // 2 @@ -184,11 +207,11 @@ def calc_padding_and_skirt(padding_type,
kernel_size, stride, input_dims, explic return padding, skirt -def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims, upscaling_factor): +def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor): kernel_height, kernel_width = kernel_size[0], kernel_size[1] if padding_type == Padding.SAME: - ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height)) - xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width)) + ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height)) + xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width)) right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0) bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0) left_pad = max(kernel_width - 1 - right_pad, 0) @@ -225,7 +248,7 @@ def convert_resizebilinear_1x1_to_add(op): op.name = op.name + "_add" op.attrs["resizebilinear"] = True # Create an input tensor filled with zeros - shape = op.outputs[0].shape + shape = op.ofm_shapes[0].as_list() tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add") tens.values = np.zeros(shape) tens.quant_values = np.zeros(shape, np.uint8) @@ -258,8 +281,8 @@ def convert_resizebilinear_to_2x2_pool(op): op.attrs["padding"] = Padding.SAME op.inputs[0].resampling_mode = resampling_mode.NEAREST - upscaled_shape = np.array(op.inputs[0].shape[1:3]) - out_shape = np.array(op.outputs[0].shape[1:3]) + upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list()) + out_shape = np.array(op.ofm_shapes[0].get_hw_as_list()) if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all(): return op @@ -276,8 +299,8 @@ def convert_resizebilinear_to_2x2_pool(op): scaled_op.outputs = outputs scaled_op.outputs[0].ops = [scaled_op] else: - shape = outputs[0].shape.copy() - shape[1:3] = upscaled_shape[0:2] + shape = op.ofm_shapes[0].as_list() + shape[1:3] = upscaled_shape out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count)) out_tens.quantization = op.outputs[0].quantization.clone() out_tens.quantization.quant_min = np.iinfo(np.int16).min @@ -300,11 +323,11 @@ def convert_resizebilinear_to_2x2_pool(op): def fixup_resizebilinear(op, arch, nng): if op.type == Op.ResizeBilinear and op.run_on_npu: - if op.inputs[0].shape == op.outputs[0].shape: + if op.ifm_shapes[0] == op.ofm_shapes[0]: # Bypass nop resizebilinear op.inputs = op.inputs[:1] op.type = Op.Identity - elif op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1: + elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1: convert_resizebilinear_1x1_to_add(op) else: convert_resizebilinear_to_2x2_pool(op) @@ -321,109 +344,26 @@ def convert_nop_split_to_identity(op, arch, nng): return op -def fixup_fully_connected_input(op, arch, nng): - if op.type == Op.FullyConnected: - inp = op.inputs[0] - weights = op.inputs[1] - - n_in_elems = weights.shape[-2] - elms = inp.elements() - batch_size = elms // n_in_elems - assert batch_size * n_in_elems == elms - - desired_shape = [batch_size, n_in_elems] - if inp.shape != desired_shape: - # mismatch, insert a reshape to fix this.
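# --- annotation (not part of the patch): the fix-up being removed here
# derived a 2D [batch, n_in_elems] view of the FC input from its element
# count; after this patch that 2D view travels on op.ifm_shapes instead of
# through an inserted Reshape. A sketch of the derivation, with an
# illustrative helper name:
def _fc_2d_shape(ifm_shape, n_in_elems):
    elems = 1
    for dim in ifm_shape:
        elems *= dim
    assert elems % n_in_elems == 0, "IFM elements must tile the weight depth"
    return [elems // n_in_elems, n_in_elems]

assert _fc_2d_shape([2, 1, 1, 8], 8) == [2, 8]
# --- end annotation ---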
- op.set_input_tensor(create_reshape_tensor(inp, desired_shape), 0) - - return op - - def convert_batched_fc_shape(op, arch, nng): if op.type == Op.FullyConnected: - ifm = op.inputs[0] - ofm = op.outputs[0] - # Check if the FC is 2D and first dimension indicates batching - # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete - if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1: - n = ifm.shape[0] + # Check if the first dimension indicates batching + if op.ifm_shapes[0].batch > 1: batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)} + n = op.ifm_shapes[0].batch h, w = batching_split.get(n, (1, n)) + op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth]) - prev_op = ifm.ops[0] - desired_shape = [1, h, w, ifm.shape[-1]] - op.ifm_shapes[0] = Shape4D(desired_shape) - - if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape: - # There is a preceding Reshape - # Compare input of prev_op and input of op, to see if prev_op can be removed - ifm_prev_op = prev_op.inputs[0] - if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm): - # prev_op can be removed - op.set_input_tensor(ifm_prev_op, 0) - else: - op.inputs[0].set_all_shapes(desired_shape) - prev_op.set_input_tensor( - create_const_tensor(prev_op.inputs[1].name, [1], DataType.int32, desired_shape), 1 - ) - prev_op.attrs["new_shape"] = desired_shape - else: - # Add reshape op to the input if there is no preceding reshape - ifm.consumer_list.remove(op) - op.set_input_tensor(create_reshape_tensor(ifm, desired_shape), 0) + op.ifm.avoid_NHCWB16 = True # Reshape Weights to be 4D. IO becomes HWIO weight_tensor = op.inputs[1] weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0) weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) - desired_shape = [1, h, w, ofm.shape[-1]] - op.ofm_shapes[0] = Shape4D(desired_shape) - - if ( - len(ofm.consumer_list) == 1 - and ofm.consumer_list[0] is not None - and ofm.consumer_list[0].type == Op.Reshape - ): - # There is a subsequent Reshape - # Compare desired shape and output of consumer op, to see if consumer op can be removed - ofm_cons_op = ofm.consumer_list[0].outputs[0] - if desired_shape == ofm_cons_op.shape and check_quantized_tens_scaling_equal(ofm, ofm_cons_op): - op.outputs[0] = ofm_cons_op - op.outputs[0].ops = [op] - else: - op.outputs[0].set_all_shapes(desired_shape) - else: - # Add reshape op to the output - op.set_output_tensor(create_reshape_tensor(ofm, desired_shape, False)) - return op - - -def fixup_pack_input(op, arch, nng): - if op.type == Op.Pack: - # Pack is also referred to as Stack - # Requires the rewrite_concat function to be called on the op afterwards - axis = int(op.attrs["axis"]) - desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:] - - # Construct 1 shape tensor to be used by all inserted reshape ops - new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, desired_shape) - - for idx, inp in enumerate(op.inputs): - reshape_out = inp.clone("_reshaped") - reshape_out.set_all_shapes(desired_shape) - - reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx)) - reshape_op.attrs["new_shape"] = desired_shape - reshape_op.inputs = [inp, new_shape_tens] - reshape_op.set_output_tensor(reshape_out) - reshape_op.set_ifm_ofm_shapes() - DebugDatabase.add_optimised(op, reshape_op) - - op.inputs[idx] = reshape_out - - op.type = Op.PackReshaped 
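# --- annotation (not part of the patch): convert_batched_fc_shape folds a
# batch of n vectors into a [1, h, w, depth] feature map, presumably so the
# batch maps onto the spatial axes. A sketch of the fold used by the added
# lines just below, with the same table as the pass; sizes without an entry
# fall back to a single row:
def _fold_batch(n):
    batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
    return batching_split.get(n, (1, n))

assert _fold_batch(8) == (2, 4)
assert _fold_batch(3) == (1, 3)
# --- end annotation ---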
- + n = op.ofm_shapes[0].batch + h, w = batching_split.get(n, (1, n)) + op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth]) + op.ofm.avoid_NHCWB16 = True return op @@ -441,12 +381,19 @@ def unfuse_activation_function(op, arch, nng): return op -def fixup_stridedslice_output(tens, arch, nng): - op = tens.ops[0] - if op.run_on_npu and op.type == Op.StridedSlice: - reshape_input_shape = tens.shape - new_axis_mask = op.attrs["new_axis_mask"] - shrink_axis_mask = op.attrs["shrink_axis_mask"] +def rewrite_stridedslice_output(op, arch, nng): + if not op.run_on_npu or op.type != Op.StridedSlice: + return op + + new_axis_mask = op.attrs["new_axis_mask"] + shrink_axis_mask = op.attrs["shrink_axis_mask"] + + if shrink_axis_mask == 0 and new_axis_mask == 0: + return op + + axis_4D = [0] * len(op.outputs) + for idx, out_tens in enumerate(op.outputs): + output_shape = list(out_tens.shape) if shrink_axis_mask != 0: n = 0 @@ -456,10 +403,16 @@ def fixup_stridedslice_output(tens, arch, nng): n += 1 shrink_axis_mask &= shrink_axis_mask - 1 axis = int(math.log2(prev_mask - shrink_axis_mask)) - reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:] + output_shape = output_shape[:axis] + [1] + output_shape[axis:] - assert len(tens.shape) == (len(op.inputs[0].shape) - n) + assert len(out_tens.shape) == (len(op.inputs[0].shape) - n) op.attrs["shrink_axis_mask"] = 0 + if axis >= 0: + axis_4D[idx] = axis + (4 - len(output_shape)) + else: + axis_4D[idx] = axis + op.ofm_shapes[idx] = Shape4D(output_shape) + elif new_axis_mask != 0: n = 0 axis = 0 @@ -468,77 +421,62 @@ def fixup_stridedslice_output(tens, arch, nng): n += 1 new_axis_mask &= new_axis_mask - 1 axis = int(math.log2(prev_mask - new_axis_mask)) - reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :] + output_shape = output_shape[:axis] + output_shape[(axis + 1) :] new_axis_mask >>= 1 - assert len(tens.shape) == (len(op.inputs[0].shape) + n) + assert len(out_tens.shape) == (len(op.inputs[0].shape) + n) op.attrs["new_axis_mask"] = 0 - else: - # Equal Rank StridedSlice, no need to insert reshape - return tens - - # Construct 1 shape tensor to be used by all inserted reshape ops - new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape) - - for idx, out_tens in enumerate(op.outputs): - op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape) - reshape_in = out_tens.clone("_reshaped") - reshape_in.set_all_shapes(reshape_input_shape) - reshape_in.ops = [op] - - reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx)) - reshape_op.attrs["new_shape"] = reshape_input_shape - reshape_op.inputs = [reshape_in, new_shape_tens] - reshape_op.set_output_tensor(out_tens) - reshape_op.set_ifm_ofm_shapes() + if axis >= 0: + axis_4D[idx] = axis + (4 - len(output_shape)) + else: + axis_4D[idx] = axis + op.ofm_shapes[idx] = Shape4D(output_shape) - op.outputs[idx] = reshape_in + if op.ofm_shapes[idx] != Shape4D(out_tens.shape): + out_tens.avoid_NHCWB16 = True - return tens + op.attrs["split_axis_4D"] = axis_4D + return op -def fixup_unpack_output(tens, arch, nng): - op = tens.ops[0] +def rewrite_unpack_output(op, arch, nng): + tens = op.outputs[0] if op.run_on_npu and op.type == Op.Unpack: # Unpack is also referred to as Unstack - # Requires the rewrite_split function to be called on the op afterwards axis = int(op.attrs["axis"]) op.type = Op.UnpackReshaped - reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:] + desired_output_shape = 
tens.shape[:axis] + [1] + tens.shape[axis:] - # Construct 1 shape tensor to be used by all inserted reshape ops - new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape) + if axis >= 0: + axis_4D = axis + (4 - len(desired_output_shape)) + else: + axis_4D = axis + axis_4D_list = [0] * len(op.outputs) for idx, out_tens in enumerate(op.outputs): - reshape_in = out_tens.clone("_reshaped") - reshape_in.set_all_shapes(reshape_input_shape) - reshape_in.ops = [op] - - reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx)) - reshape_op.attrs["new_shape"] = reshape_input_shape - reshape_op.inputs = [reshape_in, new_shape_tens] - reshape_op.set_output_tensor(out_tens) - reshape_op.set_ifm_ofm_shapes() - DebugDatabase.add_optimised(op, reshape_op) - - op.outputs[idx] = reshape_in - return tens + op.ofm_shapes[idx] = Shape4D(desired_output_shape) + axis_4D_list[idx] = axis_4D + if op.ofm_shapes[idx] != Shape4D(out_tens.shape): + out_tens.avoid_NHCWB16 = True + + op.attrs["split_axis_4D"] = axis_4D_list + return op def add_padding_fields(op, arch, nng): if op.run_on_npu: if "padding" in op.attrs: + input_shape = op.ifm_shapes[0] + output_shape = op.ofm_shapes[0] if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op(): kernel_size = op.inputs[1].shape[:2] - input_shape = op.inputs[0].shape elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum: kernel_size = op.attrs["ksize"][1:3] - input_shape = op.inputs[0].shape else: raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}") if op.type == Op.Conv2DBackpropInputSwitchedBias: - upscaling_factor = op.outputs[0].shape[1] // input_shape[1] + upscaling_factor = output_shape.height // input_shape.height padding, skirt = calc_upscaled_padding_and_skirt( op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor ) @@ -582,10 +520,10 @@ def convert_depthwise_to_conv(op, arch, nng): # switch of the operator type (and weight order) if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1): - ifm_tensor = op.inputs[0] + ifm_shape = op.ifm_shapes[0] weight_tensor = op.inputs[1] - ofm_tensor = op.outputs[0] - if (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"]): + ofm_shape = op.ofm_shapes[0] + if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]): # Change op type to Conv2d op.type = Op.Conv2DBias del op.attrs["channel_multiplier"] @@ -596,7 +534,7 @@ def convert_depthwise_to_conv(op, arch, nng): else: raise UnsupportedFeatureError( f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},", - f" ifm channels = {ifm_tensor.shape[3]}, ofm channels = {ofm_tensor.shape[3]}", + f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}", ) DebugDatabase.add_optimised(op, op) return op @@ -620,17 +558,15 @@ def optimise_strided_conv(op, arch, nng): op.type == Op.Conv2DBias and op.op_index == 0 and stride_x == 2 - and len(ifm_tensor.shape) == 4 - and ifm_tensor.shape[3] <= 4 - and ifm_tensor.shape[2] % 2 == 0 + and op.ifm_shapes[0].depth <= 4 + and op.ifm_shapes[0].width % 2 == 0 and weight_tensor is not None and weight_tensor.shape[1] >= 2 ): + ifm_shape = op.ifm_shapes[0] # IFM - ifm_reshaped = create_reshape_tensor( - ifm_tensor, [ifm_tensor.shape[0], ifm_tensor.shape[1], ifm_tensor.shape[2] // 2, ifm_tensor.shape[3] * 2] - ) - op.set_input_tensor(ifm_reshaped, 0) + op.ifm_shapes[0] = 
Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2]) + op.ifm.avoid_NHCWB16 = True # Weights weight_shape = weight_tensor.shape @@ -657,8 +593,6 @@ def optimise_strided_conv(op, arch, nng): stride_x = 1 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)}) - op.set_ifm_ofm_shapes() - return op @@ -683,27 +617,6 @@ def convert_conv_to_fc(op, arch, nng): weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1)) weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) - # The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it - # back to 4D afterwards as the next layer is expecting that shape - orig_ofm_tensor = op.outputs[0] - # Reshape this ops output to be 2D: {(N*H*W), C} (We know N H and W are all 1 so this becomes {1, C}) - fc_ofm_tensor = orig_ofm_tensor.clone("_fc") - fc_ofm_tensor.set_all_shapes([1, fc_ofm_tensor.shape[-1]]) - fc_ofm_tensor.ops = [op] - # Add a reshape after the new OFM to convert it back to the original 4D shape - reshape_name = op.name + "_reshape" - new_shape_tens = create_const_tensor(reshape_name + "_shape", [1], DataType.int32, orig_ofm_tensor.shape) - reshape_op = Operation(Op.Reshape, reshape_name) - reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape - reshape_op.inputs = [fc_ofm_tensor, new_shape_tens] - reshape_op.set_output_tensor(orig_ofm_tensor) - reshape_op.set_ifm_ofm_shapes() - - # Replace this ops OFM to point to the 2D tensor - op.outputs[0] = fc_ofm_tensor - op.set_ifm_ofm_shapes() - # Record optimisation in debug database - DebugDatabase.add_optimised(op, reshape_op) DebugDatabase.add_optimised(op, op) return op @@ -722,14 +635,6 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng): # Tidy up and assign the ifm and ofm to the new op ifm.consumer_list.remove(op) - # if not 4d, reshape ifm/ofm - if len(ifm.shape) < 4: - ifm_shaped = create_reshape_tensor(ifm, full_shape(4, ifm.shape, 1)) - ifm = ifm_shaped - if len(ofm.shape) < 4: - ofm_shaped = create_reshape_tensor(ofm, full_shape(4, ofm.shape, 1), False) - ofm = ofm_shaped - relu_fused_op.add_input_tensor(ifm) relu_fused_op.set_output_tensor(ofm) relu_fused_op.set_ifm_ofm_shapes() @@ -737,6 +642,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng): return op +# TODO remove if mem only ops can all be removed # Reorder activation op if it's after the memory only operations def fixup_act_reorder(op, arch, nng): if op.type.is_relu_op() or op.type in (Op.Sigmoid, Op.Tanh): @@ -752,8 +658,8 @@ def fixup_act_reorder(op, arch, nng): act_op_out = act_op.inputs[0].clone("_acted") act_op_out.quantization = op.outputs[0].quantization.clone() act_op.set_output_tensor(act_op_out) - act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape) - act_op.ofm_shapes[0] = Shape4D(act_op_out.shape) + act_op.ofm_shapes[0] = act_op.ifm_shapes[0].clone() + act_op.ifm_shapes[0] = prep_op.ifm_shapes[0].clone() # Update the consumer list act_op_out.consumer_list = op.outputs[0].consumer_list.copy() @@ -1078,39 +984,94 @@ def convert_tanh_sigmoid_to_lut(op, arch, nng): return op -def remove_unwanted_reshapes(op, arch, nng): - # Try to remove reshapes enclosing ElementWise operator with only one non-constant input - if not op.run_on_npu or not op.type.is_elementwise_op(): - return op +def remove_reshapes(op, arch): + if op.run_on_npu and op.type == Op.Reshape: + ofm = op.ofm + ifm = op.ifm - # Check if the ElementWise operator only have one 
non-constant input - non_const_tens = [x for x in op.inputs if x.ops[0].type != Op.Const] - if len(non_const_tens) != 1: - return op - ifm = non_const_tens[0] + # Check if quantization is the same in the input and output for the reshape ops + if not check_quantized_tens_scaling_equal(ifm, ofm): + # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors. + # In order to remove this reshape either quantization properties need to be moved to Operator, + # or the reshape needs to be replaced with a NOP. + return + + # Check if ifm is a sg input + if ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const): + # Put the reshape on CPU + op.run_on_npu = False + return + + # Check if Reshape ifm/ofm are network ifm/ofm + ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list) + ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list) + + if ifm_is_sg_ofm and ofm_is_sg_ofm: + # Both ifm and ofm are sg outputs, add a reshape to the ifm and put it on CPU + ifm_cons_list_copy = ifm.consumer_list.copy() + ifm_ops_copy = ifm.ops.copy() + for ifm_cons in ifm_cons_list_copy: + if ifm_cons is None: + # Create a reshape op with ifm as output + name = ifm.name + "_cpu_reshape" + reshape_ifm = ifm.clone() + reshape_op = Operation(Op.Reshape, name) + reshape_op.attrs["new_shape"] = ifm.shape + reshape_op.add_input_tensor(reshape_ifm) + reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, ifm.shape)) + reshape_op.set_output_tensor(ifm) + reshape_op.set_ifm_ofm_shapes() + reshape_op.run_on_npu = False + reshape_op.ofm.ops = [reshape_op] + reshape_op.ofm.consumer_list = [None] + + # Set reshape_ifm producers + for prev_op in ifm_ops_copy: + prev_op.outputs = [reshape_ifm] + reshape_ifm.ops.append(prev_op) + + # Set reshape_ifm consumers + for ifm_cons in ifm_cons_list_copy: + if ifm_cons is not None: + for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs): + if cons_ifm == ifm: + ifm_cons.set_input_tensor(reshape_ifm, ifm_idx) + + ifm = reshape_ifm + break + ifm_is_sg_ofm = False + + if ofm_is_sg_ofm: + # Bypass the Reshape by replacing ifm with ofm + ofm.ops = [] + for prev_op in ifm.ops: + prev_op.outputs = [ofm] + ofm.ops.append(prev_op) + + # All ifm consumers need to use ofm as input + for ifm_cons in ifm.consumer_list: + for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs): + if cons_ifm == ifm: + ifm_cons.set_input_tensor(ofm, ifm_idx) + if op.ifm_shapes[0] != op.ofm_shapes[0]: + ofm.avoid_NHCWB16 = True + else: + # Bypass the Reshape by replacing ofm with ifm + for cons in ofm.consumer_list: + for ifm_idx, cons_ifm in enumerate(cons.inputs): + if cons_ifm == ofm: + cons.set_input_tensor(ifm, ifm_idx) + if op.ifm_shapes[0] != op.ofm_shapes[0]: + ifm.avoid_NHCWB16 = True - # Check if operation is enclosed by Reshapes that can be removed - ofm = op.outputs[0] - prev_op = ifm.ops[0] - if ( - len(ifm.consumer_list) == 1 - and prev_op.type == Op.Reshape - and len(ofm.consumer_list) == 1 - and ofm.consumer_list[0].type == Op.Reshape - ): - # Operation is enclosed by reshapes, check if they can be removed - prev_op_ifm, prev_op_ofm = prev_op.get_ifm_ofm() - cons_op = ofm.consumer_list[0] - cons_op_ifm = ofm - cons_op_ofm = cons_op.outputs[0] - if len(prev_op_ifm.shape) == len(cons_op_ofm.shape): - # Check if quantization is the same in the input and output for the reshape ops - if check_quantized_tens_scaling_equal(prev_op_ifm, prev_op_ofm) and check_quantized_tens_scaling_equal( - cons_op_ifm, cons_op_ofm - ): -
op.set_input_tensor(prev_op_ifm, 0) - op.set_output_tensor(cons_op_ofm) - return op + +def check_reshapes(op, arch): + if op.run_on_npu and op.type == Op.Reshape: + ofm = op.ofm + + if check_quantized_tens_scaling_equal(op.ifm, ofm): + # Reshape should have been removed + raise VelaError(f"Reshape op {op} expected to have been removed, still remains") def fuse_activation_function_with_prev(op, arch, nng): @@ -1174,13 +1135,19 @@ def optimise_pad(op, arch, nng): def add_attrs_to_resizebilinear(op, arch, nng): if op.type == Op.ResizeBilinear and op.run_on_npu: input_tensor = op.inputs[0] - upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2] - out_shape = op.outputs[0].shape[1:3] - if not op.attrs["align_corners"] and out_shape == upscaled_shape: + input_shape = op.ifm_shapes[0] + upscaled_height = input_shape.height * 2 + upscaled_width = input_shape.width * 2 + out_shape = op.ofm_shapes[0] + if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width: # this means the output is supposed to be a x2 upscale, # so we need to do SAME padding op.attrs["padding"] = Padding.SAME - elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]: + elif ( + op.attrs["align_corners"] + and out_shape.height == (upscaled_height - 1) + and out_shape.width == (upscaled_width - 1) + ): # here we can just run the avg pool without padding and # produce a (M * 2 - 1, N * 2 - 1) sized output op.attrs["padding"] = Padding.VALID @@ -1229,26 +1196,52 @@ def optimise_graph_a(nng, arch, verbose_graph=False): nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, ) + # Handle Concat Ops + for idx, sg in enumerate(nng.subgraphs): + # rewrite graph pass + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [rewrite_concat_ops], rewrite_unsupported=False, + ) + + # Handle Split Ops + for idx, sg in enumerate(nng.subgraphs): + # rewrite graph pass + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, + sg, + arch, + [], + [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity], + rewrite_unsupported=False, + ) + + for idx, sg in enumerate(nng.subgraphs): + # rewrite graph pass + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False, + ) + + # Removal of reshapes + for sg in nng.subgraphs: + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) + sg.refresh_after_modification() + op_rewrite_list = [ set_tensor_equivalence, convert_depthwise_to_conv, convert_conv_to_fc, convert_softmax, optimise_strided_conv, - fixup_fully_connected_input, convert_batched_fc_shape, - fixup_pack_input, unfuse_activation_function, fixup_conv2d_backprop, fixup_relus_with_differing_ifm_ofm_scaling, fixup_act_reorder, - fixup_elementwise_with_scalars, + fixup_elementwise_with_scalars, # TODO Move to early stage? 
reorder_depthwise_weights, fixup_resizebilinear, fixup_bias_tensors, - convert_nop_split_to_identity, convert_mul_max_to_abs_or_lrelu, - remove_unwanted_reshapes, convert_lrelu, convert_tanh_sigmoid_to_lut, ] @@ -1269,24 +1262,9 @@ def optimise_graph_a(nng, arch, verbose_graph=False): [fuse_activation_function_with_prev, optimise_pad, add_padding_fields], ) - # Post-optimisation operator debug tracing + # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph for sg in nng.subgraphs: - rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [_record_optimised]) - - if verbose_graph: - nng.print_graph() - return nng - - -def optimise_graph_b(nng, arch, verbose_graph=False): - if verbose_graph: - nng.print_graph() - - for idx, sg in enumerate(nng.subgraphs): - # combined rewrite graph pass - nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [], - ) + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised]) if verbose_graph: nng.print_graph() diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index 60e62aa6..e514e76c 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -374,13 +374,13 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs): if cmd.is_npu_pass_command(): if cmd.is_first: ifm_read = cmd.ifm_tensor.address_offset_for_coordinate( - cmd.ifm_box.start_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=False + cmd.ifm_box.start_coord, cmd.ps.ifm_shapes[0], is_top_box=False ) if ifm_read is None: return 0 if cmd.is_last: write_offset = cmd.ofm_tensor.address_offset_for_coordinate( - cmd.ofm_box.end_coord, cmd.ps.ofm_shapes[0].as_list(), is_top_box=True + cmd.ofm_box.end_coord, cmd.ps.ofm_shapes[0], is_top_box=True ) if write_offset is None: return 0 @@ -393,7 +393,7 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs): if cmd.is_first: ifm_read = cmd.ifm_tensor.address_offset_for_coordinate( - cmd.ifm_box.end_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=True + cmd.ifm_box.end_coord, cmd.ps.ifm_shapes[0], is_top_box=True ) min_overlap = max(min_overlap, 0) diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py index 8e4d33a5..31434835 100644 --- a/ethosu/vela/high_level_command_to_npu_op.py +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -43,7 +43,6 @@ from .api import NpuRoundingMode from .api import NpuShape3D from .api import NpuTileBox from .architecture_features import ArchitectureFeatures -from .architecture_features import Block from .data_type import DataType from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError @@ -152,7 +151,7 @@ def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding: # because of activation function needed to be fused. 
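# --- annotation (not part of the patch): create_padding, just below, only
# keeps padding on the IFM edges that the current stripe actually touches;
# interior stripes get zero padding on their interior sides, with the
# op-level IFM width now read from ps.ifm_shapes instead of the tensor shape.
# A sketch with hypothetical stripe x-coordinates:
def _mask_lr_padding(left, right, start_x, end_x, ifm_width):
    if start_x > 0:          # stripe does not begin at the left edge
        left = 0
    if end_x < ifm_width:    # stripe does not reach the right edge
        right = 0
    return left, right

assert _mask_lr_padding(1, 1, 8, 16, 32) == (0, 0)
assert _mask_lr_padding(1, 1, 0, 32, 32) == (1, 1)
# --- end annotation ---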
if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0: left = 0 - if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < Block.from_shape(cmd.ifm_tensor.shape).width: + if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < cmd.ps.ifm_shapes[0].width: right = 0 return NpuPadding(top=top, left=left, bottom=bottom, right=right) @@ -233,7 +232,7 @@ def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]: return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point) -def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: Shape4D) -> NpuFeatureMap: +def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap: """Creates feature map with common fields populated""" fm = NpuFeatureMap() fm.region = get_region(tens, arch) @@ -244,14 +243,16 @@ def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_sh fm.layout = NpuLayout.NHCWB16 else: assert 0, "Incorrect tensor format" - height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord, fm_shape) + height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer( + box.start_coord, box.end_coord, op_shape4D + ) for idx, addr in enumerate(addresses): if addr is None: addresses[idx] = 0 fm.tiles = NpuTileBox( height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses] ) - strides = tens.get_strides() + strides = tens.get_strides(shape4D=op_shape4D) fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1])) return fm @@ -325,7 +326,7 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit op = ps.primary_op ifm_height = cmd.ifm_box.get_block().height - ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width + ifm_width = cmd.ps.ifm_shapes[0].width ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box) npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0]) @@ -401,7 +402,9 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu npu_op = NpuElementWiseOperation(elemwise_op) if elemwise_op not in UNARY_ELEMWISE_OPS: - if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape): + ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list() + ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list() + if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape): # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box @@ -416,7 +419,7 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0) else: ifm2_blk = cmd.ifm2_box.get_block() - ifm2_width = Block.from_shape(cmd.ifm2_tensor.shape).width + ifm2_width = ps.ifm_shapes[1].width npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth) set_common_op_fields(npu_op, cmd, arch) # Check if output scale needs to be overridden diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py index c2418d73..3acd5e6c 100644 --- a/ethosu/vela/npu_performance.py +++ b/ethosu/vela/npu_performance.py @@ -117,15 +117,21 @@ def get_ifm_block_depth(npu_block_type, ifm_depth, ifm_elemwidth, 
block_traversa return min(ifm_depth, ifm_blk_depth) -def get_minimal_cmd_cycles(arch, ifm_tensor, ofm_tensor, ifm_blk: Block, ofm_blk: Block, output_cycles, dpu_cycles=0): +def get_minimal_cmd_cycles( + arch, ifm_tensor, ofm_tensor, ifm_blk: Block, ofm_blk: Block, output_cycles, ifm_shape4D, ofm_shape4D, dpu_cycles=0 +): ifm_tens_blk = Tensor((1, ifm_blk.height, ifm_blk.width, ifm_blk.depth), ifm_tensor.dtype, "ifm_blk") ofm_tens_blk = Tensor((1, ofm_blk.height, ofm_blk.width, ofm_blk.depth), ofm_tensor.dtype, "ofm_blk") cycles_ifm_blk = ( - estimate_memory_transfer_efficiency(arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk) + estimate_memory_transfer_efficiency( + arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk, shape4D=ifm_shape4D + ) / arch.memory_bandwidths_per_cycle[ifm_tensor.mem_area] ) cycles_ofm_blk = ( - estimate_memory_transfer_efficiency(arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk) + estimate_memory_transfer_efficiency( + arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk, shape4D=ofm_shape4D + ) / arch.memory_bandwidths_per_cycle[ofm_tensor.mem_area] ) return ( @@ -204,7 +210,14 @@ def estimate_output_cycles( if primary_op.type.is_elementwise_op() and block_config is not None: num_elems_blk = block_config.width * block_config.height * block_config.depth cycle_cmd = get_minimal_cmd_cycles( - arch, ifm_tensor, ofm_tensor, block_config, block_config, num_elems_blk * cycle_per_elem + arch, + ifm_tensor, + ofm_tensor, + block_config, + block_config, + num_elems_blk * cycle_per_elem, + primary_op.ifm_shapes[0], + primary_op.ofm_shapes[0], ) cycle_per_elem = max(cycle_per_elem, cycle_cmd / num_elems_blk) @@ -343,7 +356,15 @@ def estimate_conv_pooling_cycles( cycles_output_blk = max(cycles_output_blk, cycles_bias_blk) cycles_cmd = get_minimal_cmd_cycles( - arch, ifm_tensor, ofm_tensor, ifm_block, ofm_block, cycles_dpu_blk, cycles_output_blk + arch, + ifm_tensor, + ofm_tensor, + ifm_block, + ofm_block, + cycles_dpu_blk, + ifm_tens_shape, + ofm_tens_shape, + cycles_output_blk, ) cycles_dpu_blk = max(cycles_dpu_blk, cycles_cmd) cycles_output_blk = max(cycles_output_blk, cycles_cmd) @@ -356,7 +377,9 @@ def estimate_conv_pooling_cycles( return total_cycles -def estimate_memory_transfer_efficiency(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None): +def estimate_memory_transfer_efficiency( + arch, mem_area, direction, tensor, block_size: Block, replace_bw=None, shape4D=None +): if tensor.format not in (TensorFormat.NHWC, TensorFormat.NHCWB16): return tensor.bandwidth() if replace_bw is None else replace_bw @@ -368,9 +391,10 @@ def estimate_memory_transfer_efficiency(arch, mem_area, direction, tensor, block tens = tensor.clone() if not tens.avoid_NHCWB16: tens.set_format(TensorFormat.NHCWB16, arch) + strides = tens.get_strides(shape4D=shape4D) if tens.format == TensorFormat.NHCWB16: - if tens.get_strides()[1] == block_size.depth: + if strides[1] == block_size.depth: burst_len = elem_size * block_size.depth * block_size.width elif is_ifm: burst_len = 16 * elem_size * block_size.width @@ -379,12 +403,12 @@ def estimate_memory_transfer_efficiency(arch, mem_area, direction, tensor, block else: assert tens.format == TensorFormat.NHWC if is_ifm: - if tens.get_strides()[3] == block_size.depth: + if strides[3] == block_size.depth: burst_len = elem_size * block_size.depth * block_size.width else: burst_len = elem_size * block_size.depth else: - if block_size.depth <= 16 and 
tens.get_strides()[3] == block_size.depth: + if block_size.depth <= 16 and strides[3] == block_size.depth: burst_len = elem_size * block_size.depth * block_size.width else: burst_len = min(64, 16 * elem_size * arch.ncores, block_size.depth * elem_size) @@ -585,12 +609,12 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, scaled_bws[arch.fast_storage_mem_area][tens.purpose][ BandwidthDirection.Write ] += estimate_memory_transfer_efficiency( - arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block + arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block, shape4D=ps.ofm_shapes[0], ) else: bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth() scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_transfer_efficiency( - arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block + arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block, shape4D=ps.ofm_shapes[0] ) for tens in ps.intermediates: @@ -612,8 +636,16 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, bw = tens.bandwidth() bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw + + op_shape = None + if ps.placement == PassPlacement.Npu and primary_op: + if tens == ps.ifm_tensor: + op_shape = ps.ifm_shapes[0] + elif tens == ps.ifm2_tensor: + op_shape = ps.ifm_shapes[1] + scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_transfer_efficiency( - arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, bw + arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, bw, op_shape ) # quick build access counts for only current pass, even though these aren't the final numbers diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 844f2985..342efd9d 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -629,7 +629,6 @@ class Operation: elif self.type == Op.StridedSlice: input_tens, begin_tens, end_tens, strides_tens = self.inputs outputs = self.outputs - out_tens = outputs[0] # Extract masks begin_mask = self.attrs["begin_mask"] @@ -641,7 +640,6 @@ class Operation: # shrink_axis_mask/new_axis_mask/ellipsis_mask is not supported by the Operation class but the operation # may have the attribute modified and handled in the graph optimization phase. 
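# --- annotation (not part of the patch): rewrite_stridedslice_output in
# graph_optimiser.py consumes these masks one set bit at a time; the idiom
# "mask &= mask - 1" clears the lowest set bit, and log2 of the cleared bit
# recovers the axis index. A sketch of that walk:
import math

def _mask_axes(mask):
    axes = []
    while mask:
        prev_mask = mask
        mask &= mask - 1                       # clear lowest set bit
        axes.append(int(math.log2(prev_mask - mask)))
    return axes

assert _mask_axes(0b0101) == [0, 2]
# --- end annotation ---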
assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0 - assert len(input_tens.shape) == len(out_tens.shape) offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True) offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False) elif self.type == Op.UnpackReshaped: diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py index 7015b799..e7e4bbbc 100644 --- a/ethosu/vela/operation_util.py +++ b/ethosu/vela/operation_util.py @@ -24,7 +24,7 @@ from .operation import ActivationFunction from .operation import Op from .operation import Operation from .operation import Padding -from .tensor import create_reshape_tensor +from .shape4d import Shape4D from .tensor import QuantizationParameters from .tensor import Tensor @@ -44,12 +44,17 @@ def create_avgpool_nop(name: str) -> Operation: def create_depthwise_maxpool( - name: str, ifm: Tensor, quantization: QuantizationParameters, activation: Optional[ActivationFunction] = None + name: str, + ifm: Tensor, + inp_shape: Shape4D, + quantization: QuantizationParameters, + activation: Optional[ActivationFunction] = None, ) -> Operation: op = Operation(Op.MaxPool, name) - height = ifm.shape[1] * ifm.shape[2] - width = ifm.shape[3] - ifm_shape = [1, height, width, 1] + height = inp_shape.height * inp_shape.width + width = inp_shape.depth + ifm_shape = Shape4D([1, height, width, 1]) + op.attrs["padding"] = Padding.VALID op.attrs["stride_w"] = 1 op.attrs["stride_h"] = 1 @@ -58,11 +63,14 @@ def create_depthwise_maxpool( op.attrs["strides"] = [1, op.attrs["stride_h"], op.attrs["stride_w"], 1] op.attrs["ksize"] = [1, op.attrs["filter_height"], op.attrs["filter_width"], 1] op.activation = activation - op.inputs = [create_reshape_tensor(ifm, ifm_shape)] + op.inputs = [ifm] ofm = Tensor([1, height, 1, 1], ifm.dtype, op.name + "_tens0") ofm.quantization = quantization op.set_output_tensor(ofm) - op.set_ifm_ofm_shapes() + op.ifm_shapes.append(ifm_shape) + op.ofm_shapes.append(Shape4D(ofm.shape)) + op.ifm.avoid_NHCWB16 = True + op.ofm.avoid_NHCWB16 = True return op @@ -95,8 +103,12 @@ def create_add( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, + ifm2_shape: Optional[Shape4D] = None, ) -> Operation: - return create_binary_elementwise(Op.Add, name, ifm, ifm2, quantization, activation, dtype, attrs) + return create_binary_elementwise( + Op.Add, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape + ) def create_rescale_add( @@ -108,8 +120,12 @@ def create_rescale_add( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, + ifm2_shape: Optional[Shape4D] = None, ) -> Operation: - op = create_binary_elementwise(Op.RescaleAdd, name, ifm, ifm2, quantization, activation, dtype, attrs) + op = create_binary_elementwise( + Op.RescaleAdd, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape + ) op.rescale = rescale return op @@ -121,8 +137,9 @@ def create_clz( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, ) -> Operation: - return create_unary_elementwise(Op.CLZ, name, ifm, quantization, activation, dtype, attrs) + return create_unary_elementwise(Op.CLZ, name, ifm, quantization, activation, dtype, attrs, ifm_shape) def create_mul( @@ -133,8 
+150,12 @@ def create_mul( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, + ifm2_shape: Optional[Shape4D] = None, ) -> Operation: - return create_binary_elementwise(Op.Mul, name, ifm, ifm2, quantization, activation, dtype, attrs) + return create_binary_elementwise( + Op.Mul, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape + ) def create_shl( @@ -145,8 +166,12 @@ def create_shl( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, + ifm2_shape: Optional[Shape4D] = None, ) -> Operation: - return create_binary_elementwise(Op.SHL, name, ifm, ifm2, quantization, activation, dtype, attrs) + return create_binary_elementwise( + Op.SHL, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape + ) def create_shr( @@ -157,8 +182,12 @@ def create_shr( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, + ifm2_shape: Optional[Shape4D] = None, ) -> Operation: - return create_binary_elementwise(Op.SHR, name, ifm, ifm2, quantization, activation, dtype, attrs) + return create_binary_elementwise( + Op.SHR, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape + ) def create_sub( @@ -169,8 +198,12 @@ def create_sub( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, + ifm2_shape: Optional[Shape4D] = None, ) -> Operation: - return create_binary_elementwise(Op.Sub, name, ifm, ifm2, quantization, activation, dtype, attrs) + return create_binary_elementwise( + Op.Sub, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape + ) def create_unary_elementwise( @@ -181,8 +214,9 @@ def create_unary_elementwise( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, ) -> Operation: - return create_binary_elementwise(op_type, name, ifm, None, quantization, activation, dtype, attrs) + return create_binary_elementwise(op_type, name, ifm, None, quantization, activation, dtype, attrs, ifm_shape, None) def create_binary_elementwise( @@ -194,19 +228,34 @@ def create_binary_elementwise( activation: Optional[ActivationFunction] = None, dtype: Optional[DataType] = None, attrs: Optional[dict] = None, + ifm_shape: Optional[Shape4D] = None, + ifm2_shape: Optional[Shape4D] = None, ) -> Operation: + if ifm_shape is None: + ifm_shape = Shape4D(ifm.shape) op = Operation(op_type, name) op.add_input_tensor(ifm) + op.ifm_shapes.append(ifm_shape) if ifm2: op.add_input_tensor(ifm2) + if ifm2_shape is None: + ifm2_shape = Shape4D(ifm2.shape) + op.ifm_shapes.append(ifm2_shape) op.activation = activation if not dtype: dtype = ifm.dtype if attrs: op.attrs.update(attrs) - ofm_shape = ifm.shape if ifm2 is None or ifm_ifm2_correct_order(ifm.shape, ifm2.shape) else ifm2.shape - ofm = Tensor(ofm_shape, dtype, f"{op.name}_tens0") + + if ifm2 is None: + ofm_shape = ifm_shape + else: + in_shape = [] if ifm.shape == [] else ifm_shape.as_list() + in2_shape = [] if ifm2.shape == [] else ifm2_shape.as_list() + ofm_shape = ifm_shape if ifm_ifm2_correct_order(in_shape, in2_shape) else ifm2_shape + + ofm = Tensor(ofm_shape.as_list(), dtype, f"{op.name}_tens0") 
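# --- annotation (not part of the patch): the OFM shape chosen above follows
# whichever input is not the broadcast operand, and these helpers now thread
# explicit Shape4D values instead of materialising Reshape tensors. A minimal
# stand-in for the right-aligned 4D padding they assume (illustrative, not
# vela's implementation):
class _Shape4D:
    def __init__(self, shape):
        shape = [1] * (4 - len(shape)) + list(shape)
        self.batch, self.height, self.width, self.depth = shape

    def as_list(self):
        return [self.batch, self.height, self.width, self.depth]

assert _Shape4D([10, 20]).as_list() == [1, 1, 10, 20]
# --- end annotation ---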
ofm.quantization = quantization op.set_output_tensor(ofm) - op.set_ifm_ofm_shapes() + op.ofm_shapes.append(ofm_shape) return op diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index ee0d7128..a95e3839 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -150,7 +150,7 @@ test_sequence = [ # ops_set npu_pre_ops, # incompatible_pack_flags - PassFlags.Cpu | PassFlags.MemoryOnly, + PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.ElementWise, # flags_to_set PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise, # flags_to_clear @@ -458,11 +458,11 @@ def pack_into_passes(nng, arch, verbose_packing=False): avgpool_out = inp.clone("_avgpooled") avgpool_out.consumer_list.append(op) avgpool_op.set_output_tensor(avgpool_out) - avgpool_op.set_ifm_ofm_shapes() + avgpool_op.ifm_shapes = op.ifm_shapes.copy() + avgpool_op.ofm_shapes = op.ofm_shapes.copy() op.inputs[0] = avgpool_out op_list.insert(0, avgpool_op) - op.set_ifm_ofm_shapes() DebugDatabase.add_optimised(op, avgpool_op) return avgpool_op diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py index a1b4feaa..8981e20b 100644 --- a/ethosu/vela/shape4d.py +++ b/ethosu/vela/shape4d.py @@ -75,3 +75,6 @@ class Shape4D: def as_list(self): return list(self._shape4D) + + def get_hw_as_list(self): + return list([self.height, self.width]) diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py index 8a1770e1..656a7e69 100644 --- a/ethosu/vela/softmax.py +++ b/ethosu/vela/softmax.py @@ -39,8 +39,8 @@ from .operation_util import create_rescale_add from .operation_util import create_shl from .operation_util import create_shr from .operation_util import create_sub +from .shape4d import Shape4D from .tensor import create_const_tensor -from .tensor import create_reshape_tensor from .tensor import TensorPurpose @@ -214,12 +214,13 @@ class SoftMax: ofm = self.op.outputs[0] # Reshape ifm/ofm (if needed) - full_shape = self.op.ifm_shapes[0].as_list() - if full_shape[0] > 1: - full_shape[1] *= full_shape[0] - full_shape[0] = 1 - ifm = create_reshape_tensor(ifm, full_shape) - ofm = create_reshape_tensor(ofm, full_shape, False) + ifm_shape = self.op.ifm_shapes[0] + if ifm_shape.batch > 1: + ifm_shape.height = ifm_shape.batch * ifm_shape.height + ifm_shape.batch = 1 + self.op.ifm.avoid_NHCWB16 = True + self.op.ofm_shapes[0] = ifm_shape.clone() + self.op.ofm.avoid_NHCWB16 = True if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype: return self.get_graph_8bit(ifm, ofm) @@ -233,7 +234,6 @@ class SoftMax: exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32) no_scale_quant = ifm.quantization.clone() no_scale_quant.scale_f32 = None - no_scale_quant.zero_point = 0 activation = ActivationFunction(Op.Clip) activation.min = ifm.quantization.quant_min activation.max = ifm.quantization.quant_max @@ -245,7 +245,6 @@ class SoftMax: one_scale_quant.zero_point = 0 two_scale_quant = one_scale_quant.clone() two_scale_quant.scale_f32 = 2.0 - ifm.quantization.zero_point = 0 pass_number = 0 def add_op_get_ofm(op): @@ -255,13 +254,25 @@ class SoftMax: return op.ofm # PASS 0 - Depthwise Maxpool - ifm_max = add_op_get_ofm(create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, no_scale_quant)) + ifm_shape = self.op.ifm_shapes[0] + ifm_max = add_op_get_ofm( + create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, ifm_shape, no_scale_quant) + ) # PASS 1 - Sub+LUT(exp) sub_op_quantization = one_scale_quant.clone() 
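# --- annotation (not part of the patch): PASS 0 above reduces over the
# channel axis by viewing [1, H, W, C] as [1, H*W, C, 1] and max-pooling the
# width (= C) down to 1, which is why create_depthwise_maxpool now takes the
# op-level Shape4D explicitly. A numpy sketch of the equivalent reduction:
import numpy as np

ifm = np.arange(2 * 3 * 4).reshape(1, 2, 3, 4)     # hypothetical [1, H, W, C]
ifm_max = ifm.reshape(1, 2 * 3, 4, 1).max(axis=2)  # pool away the C axis
assert ifm_max.shape == (1, 6, 1)
# --- end annotation ---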
sub_op_quantization.zero_point = 127 - ifm_max = create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]) - sub_op = create_sub(f"{self.op.name}_sub{pass_number}", ifm, ifm_max, sub_op_quantization, dtype=DataType.int32) + ifm_max_shape = Shape4D([1, ifm_shape.height, ifm_shape.width, 1]) + ifm_max.avoid_NHCWB16 = True + sub_op = create_sub( + f"{self.op.name}_sub{pass_number}", + ifm, + ifm_max, + sub_op_quantization, + dtype=DataType.int32, + ifm_shape=ifm_shape, + ifm2_shape=ifm_max_shape, + ) sub_op.set_activation_lut( create_const_tensor( f"{sub_op.name}_exp_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT @@ -415,7 +426,9 @@ class SoftMax: shr30_op.add_input_tensor(scaled_exp) shr30_op.add_input_tensor(right_shift) shr30_op.set_output_tensor(ofm) - shr30_op.set_ifm_ofm_shapes() + shr30_op.ifm_shapes.append(Shape4D(scaled_exp.shape)) + shr30_op.ifm_shapes.append(Shape4D(right_shift.shape)) + shr30_op.ofm_shapes.append(Shape4D(scaled_exp.shape)) DebugDatabase.add_optimised(self.op, shr30_op) return shr30_op @@ -432,12 +445,24 @@ class SoftMax: return op.ofm # PASS 0 - Depthwise Maxpool - ifm_max = add_op_get_ofm(create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, no_scale_quant)) + ifm_shape = self.op.ifm_shapes[0] + ifm_max = add_op_get_ofm( + create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, ifm_shape, no_scale_quant) + ) # PASS 1 - Sub - ifm_max = create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]) + ifm_max_shape = Shape4D([1, ifm_shape.height, ifm_shape.width, 1]) + ifm_max.avoid_NHCWB16 = True sub1_ofm = add_op_get_ofm( - create_sub(f"{self.op.name}_sub{pass_number}", ifm, ifm_max, ifm.quantization.clone(), dtype=DataType.int32) + create_sub( + f"{self.op.name}_sub{pass_number}", + ifm, + ifm_max, + ifm.quantization.clone(), + dtype=DataType.int32, + ifm_shape=ifm_shape, + ifm2_shape=ifm_max_shape, + ) ) # PASS 2 - Mul @@ -537,7 +562,9 @@ class SoftMax: shr13_op.add_input_tensor(mul_ofm) shr13_op.add_input_tensor(reciprocal_right_shift) shr13_op.set_output_tensor(ofm) - shr13_op.set_ifm_ofm_shapes() + shr13_op.ifm_shapes.append(Shape4D(mul_ofm.shape)) + shr13_op.ifm_shapes.append(Shape4D(reciprocal_right_shift.shape)) + shr13_op.ofm_shapes.append(Shape4D(mul_ofm.shape)) DebugDatabase.add_optimised(self.op, shr13_op) return shr13_op diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index fb877ca8..ef8a28fc 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -314,26 +314,6 @@ def create_const_tensor( return const_tensor -def create_reshape_tensor(tens, shape, ifm_reshape=True): - if shape == tens.shape: - return tens - # Tensors - name = tens.name + "_reshape" - reshape_ifm = tens - reshape_ofm = tens.clone("_reshaped") - reshape_ofm.set_all_shapes(shape) - if not ifm_reshape: - reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm - # Operator - reshape_op = Operation(Op.Reshape, name) - reshape_op.attrs["new_shape"] = shape - reshape_op.add_input_tensor(reshape_ifm) - reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) - reshape_op.set_output_tensor(reshape_ofm) - reshape_op.set_ifm_ofm_shapes() - return reshape_ofm if ifm_reshape else reshape_ifm - - # class that keeps track of all tensor addresses in the different memory types class TensorAddressMap: address_map: Dict = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address)) @@ -443,6 +423,10 @@ class Tensor: def address(self, address: int): 
         TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
 
+    @property
+    def is_standard_fm(self) -> bool:
+        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap
+
     def element_size(self) -> int:
         if self.element_size_bytes == 0:
             return self.dtype.size_in_bits() / 8
@@ -540,6 +524,15 @@ class Tensor:
         rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
         return rounded_size
 
+    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
+        elems = shape_num_elements(op_storage_shape)
+        elems = elems if elems else 0
+        raw_size = elems * self.element_size()
+        if raw_size == 0:
+            raw_size = 1  # force it to take up space
+        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
+        return rounded_size
+
     def storage_size_for_sub_purpose(
         self, arch, sub_purpose: TensorSubPurpose, param_a: Optional[int] = None, param_b: Optional[int] = None
     ) -> int:
@@ -614,7 +607,11 @@ class Tensor:
     def consumers(self) -> List[Operation]:
         return self.consumer_list
 
-    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple:
+    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
+        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
+        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))
+
+    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, op_shape4D: Shape4D) -> Tuple:
         # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
 
         if self.storage_shape == []:
@@ -622,12 +619,16 @@ class Tensor:
                 1,
                 1,
                 1,
-                [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None],
+                [self.address_for_coordinate(start_coord, op_shape4D=op_shape4D), None, None, None],
             )
 
-        storage_shape_4D = full_shape(4, self.storage_shape, 1)
-        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D[1])
-        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D[2])
+        if self.is_standard_fm:
+            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
+        else:
+            storage_shape_4D = Shape4D(self.storage_shape)
+
+        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
+        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)
 
         crossing_y = min(crossing_y, end_coord[1])
         crossing_x = min(crossing_x, end_coord[2])
@@ -636,39 +637,41 @@ class Tensor:
         box_width = crossing_x - start_coord[2]
 
         addresses: List = [None] * 4
-        addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list())
+        addresses[0] = self.address_for_coordinate(start_coord, op_shape4D=op_shape4D)
 
         if end_coord[2] > crossing_x:
             addresses[1] = self.address_for_coordinate(
-                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list()
+                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], op_shape4D=op_shape4D
             )
             raise UnsupportedFeatureError("Striping in vertical direction is not supported")
         if end_coord[1] > crossing_y:
             addresses[2] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list()
+                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], op_shape4D=op_shape4D
             )
         if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
             addresses[3] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list()
+                [start_coord[0], crossing_y, crossing_x, start_coord[3]], op_shape4D=op_shape4D
             )
 
         return box_height0, box_height0, box_width, addresses
 
-    def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, shape: Shape = None) -> int:
-        if shape is None:
-            shape = self.shape
-        offset = self.address_offset_for_coordinate(coord, shape, is_top_box)
+    def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, op_shape4D: Shape4D = None) -> int:
+        offset = self.address_offset_for_coordinate(coord, op_shape4D=op_shape4D, is_top_box=is_top_box)
         assert offset is not None
         return self.address + offset
 
-    def get_strides_and_coord(self, coord: Optional[Shape] = None) -> Tuple[Optional[Shape], Optional[Shape]]:
+    def get_strides_and_coord(
+        self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
+    ) -> Tuple[Optional[Shape], Optional[Shape]]:
         if coord is None:
             coord = [0] * len(self.storage_shape)
 
+        if shape4D and self.is_standard_fm:
+            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
+        else:
+            augmented_shape = full_shape(4, self.storage_shape, 1)
+
         augmented_coord = coord
-        augmented_shape = self.storage_shape
-        while len(augmented_shape) < 4:
-            augmented_shape = [1] + augmented_shape
+
         while len(augmented_coord) < 4:
             augmented_coord = [0] + augmented_coord
@@ -713,8 +716,8 @@ class Tensor:
 
         return strides, augmented_coord
 
-    def get_strides(self) -> Shape:
-        strides, _ = self.get_strides_and_coord()
+    def get_strides(self, shape4D: Optional[Shape4D] = None) -> Shape:
+        strides, _ = self.get_strides_and_coord(shape4D=shape4D)
         assert strides is not None
         return strides
 
@@ -769,13 +772,13 @@ class Tensor:
         assert 0 <= index < len(self.compressed_values)
         return index == len(self.compressed_values) - 1
 
-    def address_offset_for_coordinate(self, orig_coord: Shape, shape: Shape, is_top_box: bool = False) -> Optional[int]:
+    def address_offset_for_coordinate(
+        self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False
+    ) -> Optional[int]:
         address_offset = 0
-        coord = orig_coord
-
-        coord = coord[-len(self.storage_shape) :]
 
         if self.sub_purpose == TensorSubPurpose.Standard:
+            shape = op_shape4D.as_list() if op_shape4D else self.shape
             for idx, c in enumerate(orig_coord):
                 if is_top_box:
                     assert c > 0 and c <= shape[idx]
@@ -783,6 +786,7 @@ class Tensor:
                     assert c >= 0 and c < shape[idx]
 
         if self.format == TensorFormat.WeightsCompressed:
+            storage_size = self.storage_size()
             if len(self.weight_compressed_offsets) == 0:
                 return 0
 
@@ -814,13 +818,22 @@ class Tensor:
             assert index < len(self.weight_compressed_offsets)
             address_offset = self.weight_compressed_offsets[index]
         else:
+            coord = orig_coord
+            if op_shape4D and self.is_standard_fm:
+                storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
+                storage_size = self.storage_size_for_shape(storage_shape)
+            else:
+                storage_shape = self.storage_shape
+                coord = coord[-len(storage_shape) :]
+                storage_size = self.storage_size()
+
             if is_top_box:
                 coord = [c - 1 for c in coord]
 
             # handle wraparound for partial buffers. make sure to do this after subtracting top box:
-            coord = [c % self.storage_shape[idx] for idx, c in enumerate(coord)]
+            coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]
 
-            strides, augmented_coord = self.get_strides_and_coord(coord)
+            strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
             if strides is None:
                 return None
 
@@ -830,7 +843,7 @@ class Tensor:
             address_offset += np.dot(augmented_coord, strides)
 
         assert address_offset >= 0
-        assert address_offset <= self.storage_size()
+        assert address_offset <= storage_size
         return address_offset
 
     def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b3938bcc..b01b07c3 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -20,10 +20,12 @@ import numpy as np
 from ethosu.vela.data_type import DataType
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
+from ethosu.vela.graph_optimiser import optimise_graph_a
 from ethosu.vela.graph_optimiser import optimise_pad
 from ethosu.vela.nn_graph import Graph
 from ethosu.vela.operation import Op
 from ethosu.vela.operation import Padding
+from ethosu.vela.rewrite_graph import verify_graph_health
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import Shape4D
 from ethosu.vela.tensor import Tensor
@@ -32,50 +34,49 @@ from ethosu.vela.test import testutil
 
 
 def test_convert_batched_fc():
     """Tests shape conversion of batched fully connected"""
-    shape = [4, 8]
-    ifm = create_const_tensor("test_in", shape, np.uint8, np.zeros(shape))
-    weights = create_const_tensor("weight_in", shape, np.uint8, np.zeros(shape))
+    ifm_shape = [4, 8]
+    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
+    w_shape = [8, 4]
+    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
     ofm = Tensor(ifm.shape, np.uint8, "test_out")
     op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
 
     ifm.consumer_list.append(op)
 
-    op.ifm_shapes.append(Shape4D([4, 1, 1, 8]))
-    op.ofm_shapes.append(Shape4D([4, 1, 1, 8]))
-
     prev_op = op.clone()
-    prev_op.ifm_shapes = op.ifm_shapes
-    prev_op.ofm_shapes = op.ofm_shapes
+    prev_op.ifm_shapes = op.ifm_shapes.copy()
+    prev_op.ofm_shapes = op.ofm_shapes.copy()
 
     conv_op = convert_batched_fc_shape(op, None, None)
 
-    assert conv_op.ifm != prev_op.ifm
-    assert conv_op.ofm != prev_op.ofm
+    assert conv_op.ifm == prev_op.ifm
+    assert conv_op.ofm == prev_op.ofm
+    assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
+    assert op.ofm_shapes[0] == Shape4D([1, 2, 2, 8])
     assert conv_op.type == Op.FullyConnected
-    assert len(conv_op.ifm.shape) == 4
+    assert len(conv_op.ifm.shape) == 2
+    assert len(conv_op.ofm.shape) == 2
     assert conv_op.ifm.shape == conv_op.ofm.shape
-    assert conv_op.ifm.ops[0].type == Op.Reshape
 
-    shape = [1, 8]
-    ifm.shape = shape
-    weights.shape = shape
-    ofm.shape = shape
+    ifm.shape = [1, 8]
+    weights.shape = [8, 1]
+    ofm.shape = [1, 8]
     op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
 
     ifm.consumer_list.append(op)
 
-    op.ifm_shapes.append([1, 1, 1, 8])
-    op.ofm_shapes.append([1, 1, 1, 8])
-
     prev_op = op.clone()
-    prev_op.ifm_shapes = op.ifm_shapes
-    prev_op.ofm_shapes = op.ofm_shapes
+    prev_op.ifm_shapes = op.ifm_shapes.copy()
+    prev_op.ofm_shapes = op.ofm_shapes.copy()
 
     conv_op = convert_batched_fc_shape(op, None, None)
 
     assert conv_op.ifm == prev_op.ifm
     assert conv_op.ofm == prev_op.ofm
+    assert op.ifm_shapes[0] == prev_op.ifm_shapes[0]
+    assert op.ofm_shapes[0] == prev_op.ofm_shapes[0]
     assert conv_op.type == Op.FullyConnected
     assert len(conv_op.ifm.shape) == 2
+    assert len(conv_op.ofm.shape) == 2
     assert conv_op.ifm.shape == conv_op.ofm.shape
@@ -118,3 +119,91 @@ def test_optimise_pad():
     assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
     assert op.ifm.shape == [1, 76, 75, 64]
     assert pad_op not in op.ifm.ops
+
+
+def test_remove_reshape():
+    """
+    Tests that the expected reshapes are removed in graph optimisation
+    """
+
+    def setup_network():
+        quant = testutil.default_quant_params()
+        # create reshape1 op
+        ifm_shape = [64, 16]
+        reshape1_ofm_shape = [1, 4, 16, 16]
+        reshape1_ifm = create_const_tensor("reshape1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+        reshape1_ifm.quantization = quant
+        reshape1_ofm = create_const_tensor(
+            "reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape)
+        )
+        reshape1_ofm.quantization = quant
+        shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
+        reshape1_op = testutil.create_op(Op.Reshape, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
+        reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
+        reshape1_op.run_on_npu = True
+
+        # create reshape2 op
+        reshape2_ofm_shape = [1, 8, 8, 16]
+        reshape2_ofm = create_const_tensor(
+            "reshape2_out", reshape2_ofm_shape, DataType.uint8, np.zeros(reshape2_ofm_shape)
+        )
+        reshape2_ofm.quantization = quant
+        shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, reshape2_ofm_shape)
+        reshape2_op = testutil.create_op(Op.Reshape, [reshape1_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
+        reshape2_op.attrs["new_shape"] = reshape2_ofm_shape
+        reshape2_op.run_on_npu = True
+
+        # create conv op
+        conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
+        conv_ofm.quantization = quant.clone()
+        weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
+        weight_tens.values = np.zeros(weight_tens.shape)
+        weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
+        weight_tens.quantization = quant.clone()
+        bias_tens = Tensor([16], DataType.int32, "biases")
+
+        attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
+        attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+
+        conv2d_op = testutil.create_op(
+            Op.Conv2D, [reshape2_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
+        )
+        conv2d_op.run_on_npu = True
+
+        # create reshape3 op
+        ofm_shape = [8, 8, 16]
+        reshape3_ofm = create_const_tensor("reshape3_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+        reshape3_ofm.quantization = quant
+        shape_tens = create_const_tensor("reshape3_shape", [1], DataType.int32, ofm_shape)
+        reshape3_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape3_ofm, set_ifm_ofm_shapes=False)
+        reshape3_op.attrs["new_shape"] = ofm_shape
+        reshape3_op.run_on_npu = True
+        nng = Graph()
+        sg = testutil.create_subgraph([reshape1_op, reshape2_op, conv2d_op, reshape3_op])
+        nng.subgraphs.append(sg)
+
+        return nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op
+
+    # Test1: no Reshape op is expected to remain in the NPU subgraph,
+    # but the first one will be put on the CPU
+    # Network is Reshape-Reshape-Conv-Reshape
+    # Result is cpu_Reshape-Conv
+    nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op = setup_network()
+    arch = testutil.create_arch()
+    assert verify_graph_health(nng)
+    nng = optimise_graph_a(nng, arch)
+    assert verify_graph_health(nng)
+    assert conv2d_op.ifm == reshape1_op.ofm
+    assert conv2d_op.ofm == reshape3_op.ofm
+
+    # Test2: reshape2 has different quantisation, so this Reshape op is expected to remain
+    # Network is Reshape-Reshape-Conv-Reshape
+    # Expected result is cpu_Reshape-Reshape-Conv
+    nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op = setup_network()
+    quant_zp32 = testutil.default_quant_params()
+    quant_zp32.zero_point = 32
+    reshape2_op.ofm.quantization = quant_zp32
+    assert verify_graph_health(nng)
+    nng = optimise_graph_a(nng, arch)
+    assert verify_graph_health(nng)
+    assert conv2d_op.ofm == reshape3_op.ofm
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 96aeb7eb..02e01a51 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -113,14 +113,15 @@ def create_op_with_quant_tensors(
     return op
 
 
-def create_op(op_type, inputs, output, attrs=None):
+def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
     op = Operation(op_type, output.name + "_op")
     for input in inputs:
         op.add_input_tensor(input)
     op.set_output_tensor(output)
     if attrs is not None:
         op.attrs = attrs
-    op.set_ifm_ofm_shapes()
+    if set_ifm_ofm_shapes:
+        op.set_ifm_ofm_shapes()
     return op
 
-- 
cgit v1.2.1
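
Note for readers unfamiliar with the Shape4D bookkeeping this patch leans on: instead of
inserting Reshape ops, SoftMax.get_graph() above now folds the batch dimension into height
on the op's Shape4D. The class below is a minimal standalone sketch, a simplified stand-in
for ethosu/vela/shape4d.py (the real class carries more helpers); only the batch/height
folding and the get_hw_as_list() helper added by this patch are shown.

    # Simplified stand-in for ethosu/vela/shape4d.py -- illustration only.
    class Shape4D:
        def __init__(self, shape):
            assert len(shape) == 4
            self.batch, self.height, self.width, self.depth = shape

        def as_list(self):
            return [self.batch, self.height, self.width, self.depth]

        # mirrors the helper added in this patch
        def get_hw_as_list(self):
            return [self.height, self.width]

    ifm_shape = Shape4D([4, 32, 32, 8])
    # What SoftMax.get_graph() now does instead of creating Reshape ops:
    # fold the batch dimension into height, so the kernels see batch 1.
    if ifm_shape.batch > 1:
        ifm_shape.height = ifm_shape.batch * ifm_shape.height
        ifm_shape.batch = 1

    assert ifm_shape.as_list() == [1, 128, 32, 8]
    assert ifm_shape.get_hw_as_list() == [128, 32]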
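Similarly, the new Tensor.get_4D_storage_shape_for_shape() lets addressing work from a
per-operation shape rather than the tensor-level storage shape: the op shape is rounded up
to the tensor's storage rounding quantum. A rough sketch under the assumption that an
NHCWB16 feature map uses a (1, 1, 1, 16) quantum; full_shape and shape_round_to_quantum
are simplified stand-ins for the real helpers in numeric_util/tensor:

    import math

    # Simplified stand-ins for the helpers used by get_4D_storage_shape_for_shape.
    def full_shape(ndim, shape, pad_value):
        # left-pad a shape with pad_value until it has ndim entries
        return [pad_value] * (ndim - len(shape)) + list(shape)

    def shape_round_to_quantum(shape, quantum):
        # round each dimension up to a multiple of the corresponding quantum
        return [q * math.ceil(s / q) for s, q in zip(shape, quantum)]

    op_shape = [1, 8, 8, 20]
    rounding_quantum = full_shape(4, [1, 16], 1)  # assumed NHCWB16-style quantum -> [1, 1, 1, 16]
    print(shape_round_to_quantum(op_shape, rounding_quantum))  # [1, 8, 8, 32]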